aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--configure.py4
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc27
-rw-r--r--tensorflow/cc/gradients/nn_grad_test.cc27
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc64
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc22
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc21
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scan_ops.cc3
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py6
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc58
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h28
-rw-r--r--tensorflow/compiler/xla/layout_util.cc6
-rw-r--r--tensorflow/compiler/xla/layout_util.h4
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc5
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h2
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py25
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc27
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD12
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc43
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h11
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc26
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc21
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_options.h33
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto9
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc52
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc28
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h32
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc101
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc67
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc27
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc108
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h13
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc190
-rw-r--r--tensorflow/compiler/xla/shape_util.cc2
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/custom_call_test.cc50
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc12
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt14
-rw-r--r--tensorflow/contrib/cmake/README.md345
-rw-r--r--tensorflow/contrib/distribute/python/BUILD18
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py16
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py2
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py5
-rw-r--r--tensorflow/contrib/distribute/python/moving_averages_test.py141
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py4
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn.py54
-rw-r--r--tensorflow/contrib/feature_column/BUILD21
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py72
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py280
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py912
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc156
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h7
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc15
-rw-r--r--tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.pngbin0 -> 18946 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.pngbin0 -> 21380 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md21
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/android_build.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md18
-rw-r--r--tensorflow/contrib/lite/java/BUILD95
-rw-r--r--tensorflow/contrib/lite/java/aar_with_jni.bzl5
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java20
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java46
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java14
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD1
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD16
-rw-r--r--tensorflow/contrib/lite/kernels/internal/legacy_types.h (renamed from tensorflow/compiler/xla/service/gpu/gpu_options.cc)18
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h7
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h5
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_eval.cc3
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc11
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs8
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h162
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc93
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc140
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc37
-rw-r--r--tensorflow/contrib/lite/toco/model.h9
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc83
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc34
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h6
-rw-r--r--tensorflow/contrib/optimizer_v2/BUILD11
-rw-r--r--tensorflow/contrib/optimizer_v2/adadelta.py75
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad.py79
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad_test.py3
-rw-r--r--tensorflow/contrib/optimizer_v2/adam.py129
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py68
-rw-r--r--tensorflow/contrib/optimizer_v2/gradient_descent.py40
-rw-r--r--tensorflow/contrib/optimizer_v2/momentum.py69
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py1205
-rw-r--r--tensorflow/contrib/optimizer_v2/rmsprop.py154
-rw-r--r--tensorflow/contrib/optimizer_v2/rmsprop_test.py7
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py2
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py40
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py48
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax.py27
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py28
-rw-r--r--tensorflow/contrib/stateless/BUILD5
-rw-r--r--tensorflow/contrib/stateless/__init__.py9
-rw-r--r--tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py154
-rw-r--r--tensorflow/contrib/stateless/python/stateless_ops.py214
-rw-r--r--tensorflow/contrib/tpu/BUILD1
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto6
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt24
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt46
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc35
-rw-r--r--tensorflow/core/common_runtime/constant_folding.h4
-rw-r--r--tensorflow/core/common_runtime/executor.cc4
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.cc5
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.h5
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc9
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc5
-rw-r--r--tensorflow/core/framework/shape_inference.cc9
-rw-r--r--tensorflow/core/framework/shape_inference.h9
-rw-r--r--tensorflow/core/graph/graph.cc13
-rw-r--r--tensorflow/core/graph/graph.h5
-rw-r--r--tensorflow/core/graph/node_builder.cc8
-rw-r--r--tensorflow/core/grappler/op_types.cc22
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD10
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h19
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h15
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h44
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc116
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer.h21
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc75
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h15
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc62
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc162
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc76
-rw-r--r--tensorflow/core/kernels/data/BUILD14
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc47
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h20
-rw-r--r--tensorflow/core/kernels/data/dataset_utils_test.cc46
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc162
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc196
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc62
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc27
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc79
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc26
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.h2
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/data/shuffle_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/writer_ops.cc12
-rw-r--r--tensorflow/core/kernels/random_op.cc34
-rw-r--r--tensorflow/core/kernels/relu_op.cc153
-rw-r--r--tensorflow/core/kernels/relu_op.h61
-rw-r--r--tensorflow/core/kernels/relu_op_functor.h30
-rw-r--r--tensorflow/core/kernels/relu_op_gpu.cu.cc18
-rw-r--r--tensorflow/core/kernels/stateless_random_ops.cc155
-rw-r--r--tensorflow/core/kernels/unique_op.cc15
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt292
-rw-r--r--tensorflow/core/ops/math_ops.cc19
-rw-r--r--tensorflow/core/ops/math_ops_test.cc12
-rw-r--r--tensorflow/core/ops/nn_ops.cc15
-rw-r--r--tensorflow/core/ops/ops.pbtxt127
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc3
-rw-r--r--tensorflow/core/ops/stateless_random_ops.cc53
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto4
-rw-r--r--tensorflow/go/op/wrappers.go834
-rw-r--r--tensorflow/python/BUILD13
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py8
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils.py34
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils_test.py19
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/BUILD13
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py31
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/random_dataset_test.py45
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py21
-rw-r--r--tensorflow/python/data/experimental/ops/random_ops.py21
-rw-r--r--tensorflow/python/data/experimental/ops/shuffle_ops.py21
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py95
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py25
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py29
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py22
-rw-r--r--tensorflow/python/data/util/BUILD1
-rw-r--r--tensorflow/python/data/util/random_seed.py5
-rw-r--r--tensorflow/python/data/util/random_seed_test.py13
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/function.py179
-rw-r--r--tensorflow/python/eager/function_test.py55
-rw-r--r--tensorflow/python/eager/imperative_grad.py5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc2
-rw-r--r--tensorflow/python/feature_column/feature_column.py53
-rw-r--r--tensorflow/python/framework/op_def_library.py3
-rw-r--r--tensorflow/python/framework/ops.py28
-rw-r--r--tensorflow/python/framework/test_util.py15
-rwxr-xr-xtensorflow/python/keras/BUILD155
-rw-r--r--tensorflow/python/keras/activations.py5
-rw-r--r--tensorflow/python/keras/activations_test.py10
-rw-r--r--tensorflow/python/keras/backend.py83
-rw-r--r--tensorflow/python/keras/backend_test.py44
-rw-r--r--tensorflow/python/keras/callbacks.py4
-rw-r--r--tensorflow/python/keras/engine/network.py9
-rw-r--r--tensorflow/python/keras/engine/training.py6
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py4
-rw-r--r--tensorflow/python/keras/engine/training_test.py4
-rw-r--r--tensorflow/python/keras/layers/convolutional.py177
-rw-r--r--tensorflow/python/keras/layers/convolutional_test.py31
-rw-r--r--tensorflow/python/keras/layers/pooling.py185
-rw-r--r--tensorflow/python/keras/layers/pooling_test.py30
-rw-r--r--tensorflow/python/keras/layers/wrappers.py3
-rw-r--r--tensorflow/python/keras/optimizer_v2/adadelta.py116
-rw-r--r--tensorflow/python/keras/optimizer_v2/adadelta_test.py166
-rw-r--r--tensorflow/python/keras/optimizer_v2/adagrad.py119
-rw-r--r--tensorflow/python/keras/optimizer_v2/adagrad_test.py276
-rw-r--r--tensorflow/python/keras/optimizer_v2/adam.py203
-rw-r--r--tensorflow/python/keras/optimizer_v2/adam_test.py333
-rw-r--r--tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py761
-rw-r--r--tensorflow/python/keras/optimizer_v2/optimizer_v2.py1349
-rw-r--r--tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py277
-rw-r--r--tensorflow/python/keras/optimizer_v2/rmsprop.py239
-rw-r--r--tensorflow/python/keras/optimizer_v2/rmsprop_test.py444
-rw-r--r--tensorflow/python/keras/optimizer_v2/sgd.py170
-rw-r--r--tensorflow/python/keras/optimizer_v2/sgd_test.py759
-rw-r--r--tensorflow/python/keras/testing_utils.py5
-rw-r--r--tensorflow/python/keras/utils/conv_utils.py45
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils.py17
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py26
-rw-r--r--tensorflow/python/keras/utils/np_utils.py5
-rw-r--r--tensorflow/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/kernel_tests/benchmark_test.py2
-rw-r--r--tensorflow/python/kernel_tests/bincount_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/cholesky_op_test.py7
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py50
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py11
-rw-r--r--tensorflow/python/kernel_tests/determinant_op_test.py9
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/matrix_band_part_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_exponential_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_inverse_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_logarithm_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_solve_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py120
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py3
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py5
-rw-r--r--tensorflow/python/ops/array_ops.py9
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py48
-rw-r--r--tensorflow/python/ops/control_flow_ops_benchmark.py122
-rw-r--r--tensorflow/python/ops/custom_gradient.py44
-rw-r--r--tensorflow/python/ops/gradients_impl.py30
-rw-r--r--tensorflow/python/ops/image_ops_test.py62
-rw-r--r--tensorflow/python/ops/nn_grad.py15
-rw-r--r--tensorflow/python/ops/nn_ops.py3
-rw-r--r--tensorflow/python/ops/parsing_ops.py13
-rw-r--r--tensorflow/python/ops/while_v2.py3
-rw-r--r--tensorflow/python/platform/benchmark.py14
-rwxr-xr-xtensorflow/python/pywrap_tfe.i1
-rw-r--r--tensorflow/python/training/monitored_session.py8
-rw-r--r--tensorflow/python/training/moving_averages.py49
-rw-r--r--tensorflow/python/util/protobuf/compare.py18
-rw-r--r--tensorflow/python/util/util.cc8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt4
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.android5
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh2
-rw-r--r--tensorflow/tools/pip_package/setup.py15
-rwxr-xr-xtensorflow/workspace.bzl37
-rw-r--r--third_party/jpeg/BUILD2
-rw-r--r--third_party/jpeg/BUILD.bazel (renamed from third_party/jpeg/jpeg.BUILD)11
-rw-r--r--third_party/jpeg/BUILD.system (renamed from third_party/systemlibs/jpeg.BUILD)0
-rw-r--r--third_party/jpeg/jpeg_helpers.BUILD.bazel1
-rw-r--r--third_party/jpeg/workspace.bzl16
-rw-r--r--third_party/nasm/BUILD1
-rw-r--r--third_party/nasm/BUILD.bazel (renamed from third_party/nasm.BUILD)12
-rw-r--r--third_party/nasm/BUILD.system (renamed from third_party/systemlibs/nasm.BUILD)0
-rw-r--r--third_party/nasm/workspace.bzl17
328 files changed, 13386 insertions, 4502 deletions
diff --git a/configure.py b/configure.py
index 65b4622995..89dc79b6b6 100644
--- a/configure.py
+++ b/configure.py
@@ -383,7 +383,9 @@ def set_build_var(environ_cp,
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
environ_cp[var_name] = var
if var == '1':
- write_to_bazelrc('build --define %s=true' % option_name)
+ write_to_bazelrc(
+ 'build:%s --define %s=true' % (bazel_config_name, option_name))
+ write_to_bazelrc('build --config=%s' % bazel_config_name)
elif bazel_config_name is not None:
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index b587e63227..9d2208d84d 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -411,6 +411,7 @@ tf_cc_test(
srcs = ["gradients/nn_grad_test.cc"],
deps = [
":cc_ops",
+ ":cc_ops_internal",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 588e96cb19..2a32a2ed6f 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -143,6 +143,33 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
+Status LeakyReluGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ float alpha;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+ internal::LeakyReluGrad::Attrs attrs;
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0),
+ attrs.Alpha(alpha));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper);
+
+Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ float alpha;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+ internal::LeakyReluGrad::Attrs attrs;
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1),
+ attrs.Alpha(alpha));
+ grad_outputs->push_back(dx);
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper);
+
Status EluGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index aa72cf7ba2..f5a09e09dc 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -160,6 +161,32 @@ TEST_F(NNGradTest, Relu6Grad) {
RunTest(x, x_init_value, y, shape);
}
+TEST_F(NNGradTest, LeakyReluGrad) {
+ TensorShape shape({5, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ auto y = ops::internal::LeakyRelu(scope_, x);
+ // Avoid input values where Leaky ReLU gradient is not well defined (around
+ // zero).
+ Tensor x_init_value = test::AsTensor<float>(
+ {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+ {5, 2});
+ RunTest(x, x_init_value, y, shape);
+}
+
+TEST_F(NNGradTest, LeakyReluGradGrad) {
+ TensorShape shape({5, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ // Avoid input values where Leaky ReLU gradient is not well defined (around
+ // zero).
+ Tensor x_init_value = test::AsTensor<float>(
+ {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2});
+ Tensor features = test::AsTensor<float>(
+ {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+ {5, 2});
+ auto y = ops::internal::LeakyReluGrad(scope_, x, features);
+ RunTest(x, x_init_value, y, shape);
+}
+
TEST_F(NNGradTest, EluGrad) {
TensorShape shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 28e09d7b79..0362682bd6 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -94,8 +94,9 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
+ Graph* g = body->graph;
- // Check if the graph has Switch or Merge node before optimizing the graph.
+ // Check if the graph has Switch or Merge node.
bool has_switch_or_merge = false;
for (Node* n : body->graph->nodes()) {
if (n->type_string() == "Switch" || n->type_string() == "Merge") {
@@ -108,58 +109,13 @@ Status FunctionalizeControlFlowForFunction(
// in function body. We still need to rewrite those functions and modify
// corresponding nodes.
- // Call graph optimizer. The most important optimization we need is constant
- // folding, which will replace ops like Shape/BroadcastGradientArgs with
- // constant shape input. Without this optimization, those ops might become
- // dynamic input for then/else body function and XLA will complain that input
- // is not compile time constant. We enable function inlining as well, because
- // otherwise we won't be able to infer shape for any node depending on
- // function call nodes.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_opt_", func_name),
- *body->graph, fld);
- }
- // Optimizer accepts std::unique_ptr<Graph>* as input and might change
- // underlying pointer, thus we create a new Graph and copy from body->graph.
- std::unique_ptr<Graph> optimized_graph(new Graph(fld));
- CopyGraph(*body->graph, optimized_graph.get());
- OptimizerOptions opts;
- opts.set_opt_level(OptimizerOptions::L0);
- opts.set_do_function_inlining(true);
- opts.set_do_constant_folding(true);
- GraphOptimizer optimizer(opts);
- auto cf_consider_fn = [](const Node* n) {
- // Skip SymbolicGradient op when doing constant folding.
- // Enabling SymbolicGradient op in constant folding requires
- // flr->device() to be non-null, and here we have not constructed
- // proper Device object yet (it will be constructed in XlaCompiler).
- return n->type_string() != FunctionLibraryDefinition::kGradientOp;
- };
- optimizer.Optimize(flr, flr->env(),
- /*device=*/nullptr, &optimized_graph,
- /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
- cf_consider_fn);
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_opt_", func_name),
- *optimized_graph, fld);
- }
- // Some inlined functions might have Switch/Merge nodes.
- for (Node* n : optimized_graph->nodes()) {
- if (n->type_string() == "Switch" || n->type_string() == "Merge") {
- has_switch_or_merge = true;
- break;
- }
- }
-
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : optimized_graph->nodes()) {
+ for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -215,7 +171,7 @@ Status FunctionalizeControlFlowForFunction(
// pointer. That's fine because in that case, associated_functions will
// only have one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- optimized_graph.get(), n, fld, associated_function, new_name));
+ g, n, fld, associated_function, new_name));
}
}
}
@@ -227,21 +183,21 @@ Status FunctionalizeControlFlowForFunction(
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *optimized_graph, fld);
+ *g, fld);
}
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *optimized_graph, fld);
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
+ fld);
}
}
if (*modified) {
// Add rewritten FunctionDef into library.
FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
- &functionalized_fdef));
+ TF_RETURN_IF_ERROR(
+ GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
if (func_name == new_func_name) {
VLOG(2) << "Replacing function " << func_name;
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 3d81ae9eb8..f210bfbd88 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -88,20 +88,30 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
- xla::Shape xla_shape =
- xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes());
+ // The argmax function expects row-major layout.
+ xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
+ xla::S64, output_shape.dim_sizes());
+ std::vector<xla::Shape> arg_shapes;
+ for (const xla::XlaOp& arg : args) {
+ auto shape_status = b.GetShape(arg);
+ OP_REQUIRES_OK(ctx, shape_status.status());
+ xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
+ *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout(
+ xla::ShapeUtil::Rank(arg_shape));
+ arg_shapes.push_back(std::move(arg_shape));
+ }
// Tell XLA to call the custom code, defined in
// index_ops_kernel_argmax_float_1d.cc.
xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
- output =
- xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
+ output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
+ xla_shape, arg_shapes);
break;
case 2:
- output =
- xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
+ output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
+ xla_shape, arg_shapes);
break;
default:
OP_REQUIRES(ctx, false,
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index 8102faad28..8eee5b1299 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel {
std::vector<int64> window_dimensions;
std::vector<int64> window_strides;
+ std::vector<int64> base_dilations;
+ std::vector<int64> window_dilations;
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
"window_dimensions", &window_dimensions));
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
&window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations",
+ &base_dilations));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dilations", &window_dilations));
const int rank = input_shape.dims();
OP_REQUIRES(context, rank == window_dimensions.size(),
@@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel {
"The size of window_strides must be equal to the input "
"rank (",
window_strides.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == base_dilations.size(),
+ errors::InvalidArgument(
+ "The size of base_dilations must be equal to the input "
+ "rank (",
+ base_dilations.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_dilations.size(),
+ errors::InvalidArgument(
+ "The size of window_dilations must be equal to the input "
+ "rank (",
+ window_dilations.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel {
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), *reducer.computation,
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
context->SetOutput(0, output);
}
@@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("XlaReduceWindow")
.CompileTimeConstInput("window_dimensions")
.CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("base_dilations")
+ .CompileTimeConstInput("window_dilations")
.CompileTimeConstInput("padding"),
ReduceWindowOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index ab094d7dd1..57afd608de 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel {
}
auto output = xla::ReduceWindowWithGeneralPadding(
XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
- *reducer, window_dims, window_strides, padding);
+ *reducer, window_dims, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
output =
XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0));
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 557911553d..bd2c0a5ee8 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -283,6 +283,8 @@ REGISTER_OP("XlaReduceWindow")
.Input("init_value: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
+ .Input("base_dilations: Tindices")
+ .Input("window_dilations: Tindices")
.Input("padding: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index bc7924c371..5e86b5d8ec 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -320,6 +320,8 @@ def reduce_window(operand,
reducer,
window_dimensions,
window_strides=None,
+ base_dilations=None,
+ window_dilations=None,
padding=None,
name=None):
"""Wraps the XLA ReduceWindow operator.
@@ -343,12 +345,16 @@ def reduce_window(operand,
A tensor that represents the output of the reduce_window operator.
"""
window_strides = window_strides or [1] * len(window_dimensions)
+ base_dilations = base_dilations or [1] * len(window_dimensions)
+ window_dilations = window_dilations or [1] * len(window_dimensions)
padding = padding or [(0, 0)] * len(window_dimensions)
return gen_xla_ops.xla_reduce_window(
input=operand,
init_value=init,
window_dimensions=window_dimensions,
window_strides=window_strides,
+ base_dilations=base_dilations,
+ window_dilations=window_dilations,
padding=padding,
computation=reducer,
name=name)
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index d196252db1..e7cf9ae363 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1279,9 +1279,10 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
});
}
-XlaOp XlaBuilder::CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands,
- const Shape& shape, const string& opaque) {
+XlaOp XlaBuilder::CustomCall(
+ const string& call_target_name, absl::Span<const XlaOp> operands,
+ const Shape& shape, const string& opaque,
+ absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1293,6 +1294,31 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
instr.set_custom_call_opaque(opaque);
+ if (operand_shapes_with_layout.has_value()) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument(
+ "Result shape must have layout for custom call with constrained "
+ "layout.");
+ }
+ if (operands.size() != operand_shapes_with_layout->size()) {
+ return InvalidArgument(
+ "Must specify a shape with layout for each operand for custom call "
+ "with constrained layout; given %d shapes, expected %d",
+ operand_shapes_with_layout->size(), operands.size());
+ }
+ instr.set_constrain_layout(true);
+ int64 operand_num = 0;
+ for (const Shape& operand_shape : *operand_shapes_with_layout) {
+ if (!LayoutUtil::HasLayout(operand_shape)) {
+ return InvalidArgument(
+ "No layout specified for operand %d for custom call with "
+ "constrained layout.",
+ operand_num);
+ }
+ *instr.add_operand_shapes_with_layout() = operand_shape;
+ ++operand_num;
+ }
+ }
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -1789,9 +1815,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
std::vector<std::pair<int64, int64>> padding_values =
MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
window_strides, padding);
- return ReduceWindowWithGeneralPadding(operand, init_value, computation,
- window_dimensions, window_strides,
- padding_values);
+ return ReduceWindowWithGeneralPadding(
+ operand, init_value, computation, window_dimensions, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
});
}
@@ -1800,6 +1826,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1810,7 +1838,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
MakeWindow(window_dimensions, window_strides, padding,
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
+ /*lhs_dilation=*/base_dilations,
+ /*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
@@ -2687,7 +2716,16 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque) {
- return builder->CustomCall(call_target_name, operands, shape, opaque);
+ return builder->CustomCall(call_target_name, operands, shape, opaque,
+ /*operand_shapes_with_layout=*/absl::nullopt);
+}
+
+XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ const string& opaque) {
+ return builder->CustomCall(call_target_name, operands, shape, opaque,
+ operand_shapes_with_layout);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
@@ -2800,10 +2838,12 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
- padding);
+ base_dilations, window_dilations, padding);
}
XlaOp CrossReplicaSum(const XlaOp& operand,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cd0d5ca5d3..9ceede7a79 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -577,9 +577,10 @@ class XlaBuilder {
absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
- XlaOp CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape,
- const string& opaque);
+ XlaOp CustomCall(
+ const string& call_target_name, absl::Span<const XlaOp> operands,
+ const Shape& shape_with_layout, const string& opaque,
+ absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -671,6 +672,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
@@ -1195,6 +1198,10 @@ class XlaBuilder {
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque);
+ friend XlaOp CustomCallWithLayout(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
+ absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
@@ -1245,6 +1252,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
friend XlaOp CrossReplicaSum(const XlaOp& operand,
absl::Span<const ReplicaGroup> replica_groups);
@@ -1728,6 +1737,17 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque = "");
+// Overload which constructs a custom call with fixed layouts. The operands will
+// have the layouts specified by |operand_shapes_with_layout| when provided to
+// external code, and the external code is expected to produce a result with the
+// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
+// and |operand_shapes_with_layout| must have layouts.
+XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands,
+ const Shape& shape_with_layout,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ const string& opaque = "");
+
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
@@ -1818,6 +1838,8 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index d310335618..3c8db9aa45 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -65,6 +65,12 @@ void SetDefaultLayoutToContainer(
return layout;
}
+/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
+ std::vector<int64> layout(rank);
+ std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
+ return MakeLayout(layout);
+}
+
/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
absl::Span<const int64> major_to_minor) {
Layout layout;
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index b78883c2d8..af032b1cae 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -40,6 +40,10 @@ class LayoutUtil {
static Layout MakeLayoutFromMajorToMinor(
absl::Span<const int64> major_to_minor);
+ // Returns a layout with descending ((i.e. {n, n-1, ..., 0}) minor-to-major
+ // dimensions.
+ static Layout MakeDescendingLayout(int64 rank);
+
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
static Layout MakeSparseLayout(int64 max_sparse_elements);
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index cd5fd33029..ffa336f304 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -532,10 +532,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
}
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 2166bb6721..43332e0abd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -278,6 +278,8 @@ class LocalComputationBuilder {
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64> > padding);
LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index bb303c5678..f8197488fb 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -995,7 +995,30 @@ class ComputationBuilder(object):
window_strides)
return self._client.ReduceWindowWithGeneralPadding(
operand, init_value, computation_to_apply.c_local_computation,
- window_dimensions, window_strides, pads)
+ window_dimensions, window_strides, (), (), pads)
+
+ def ReduceWindowWithGeneralPadding(
+ self, operand, init_value, computation_to_apply, window_dimensions,
+ window_strides, base_dilations, window_dilations, padding):
+ """Enqueues a windowed reduction operation onto the computation.
+
+ Args:
+ operand: reduction operand (LocalOp).
+ init_value: reduction initial value (LocalOp).
+ computation_to_apply: a binary reduction function (Computation).
+ window_dimensions: dimensions of window (sequence of integers).
+ window_strides: strides for window (sequence of integers).
+ base_dilations: dilations for the base (sequence of integers).
+ window_dilations: dilations for window (sequence of integers).
+ padding: length-N array-like of pairs of integers of (low, high) padding.
+
+ Returns:
+ A LocalOp representing the added ReduceWindow op.
+ """
+ return self._client.ReduceWindowWithGeneralPadding(
+ operand, init_value, computation_to_apply.c_local_computation,
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding)
def RngNormal(self, mu, sigma, dims):
"""Enqueues an RngNormal operation onto the computation.
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 75dae7a714..86d9dbea90 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2057,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return Status::OK();
}
+ // Bail on dilation.
+ if (window_util::HasDilation(window)) {
+ VLOG(10) << "Not folding pad into reduce-window as there is dilation.";
+ return Status::OK();
+ }
+
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
<< (convert != nullptr
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a70abb117a..b2abdb39a5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -688,8 +688,25 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ input_index[i] = NSWSub(
+ NSWAdd(strided_index,
+ NSWMul(window_index[i],
+ b_.getInt64(window.dimensions(i).window_dilation()))),
+ b_.getInt64(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())),
+ b_.getInt64(0));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = dilation_condition;
+ } else {
+ in_bounds_condition = And(in_bounds_condition, dilation_condition);
+ }
+
+ // Apply base dilation to the index.
+ input_index[i] =
+ SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
@@ -728,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
/*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32, F16}));
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(reduce_window->window())) {
- return Unimplemented(
- "Dilation for ReduceWindow is not implemented on CPU.");
- }
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 522e9f5948..350fd32537 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -404,6 +404,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
],
)
@@ -780,7 +781,6 @@ cc_library(
srcs = ["gpu_layout_assignment.cc"],
hdrs = ["gpu_layout_assignment.h"],
deps = [
- ":gpu_options",
":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
@@ -882,16 +882,6 @@ cc_library(
)
cc_library(
- name = "gpu_options",
- srcs = ["gpu_options.cc"],
- hdrs = ["gpu_options.h"],
- deps = [
- "//tensorflow/compiler/xla/service:hlo_module_config",
- "//tensorflow/core:lib_internal",
- ],
-)
-
-cc_library(
name = "stream_executor_util",
srcs = ["stream_executor_util.cc"],
hdrs = ["stream_executor_util.h"],
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 7125673887..6d4a72038f 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -145,7 +145,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
-StatusOr<std::tuple<int64, bool, int64>>
+StatusOr<CudnnConvolutionAlgorithmPicker::AutotuneResult>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
HloCustomCallInstruction* instr) {
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -316,9 +316,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< AlgorithmToString(best_result.algorithm()) << ", takes "
<< best_result.elapsed_time_in_ms() << "ms, and uses "
<< best_result_bytes_used << "B of scratch memory.";
- return std::make_tuple(best_result.algorithm().algo_id(),
- best_result.algorithm().tensor_ops_enabled(),
- best_result_bytes_used);
+ return AutotuneResult{best_result.algorithm().algo_id(),
+ best_result.algorithm().tensor_ops_enabled(),
+ best_result_bytes_used,
+ absl::Milliseconds(best_result.elapsed_time_in_ms())};
}
return InternalError(
@@ -331,41 +332,37 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
- StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+ StatusOr<AutotuneResult> best_algo_or =
PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
-
- if (!alg_scratch_and_tc.ok()) {
- LOG(ERROR) << alg_scratch_and_tc.status();
+ if (!best_algo_or.ok()) {
+ LOG(ERROR) << best_algo_or.status();
return false;
}
- int64 algorithm;
- bool tensor_ops_enabled;
- int64 scratch_bytes;
-
- std::tie(algorithm, tensor_ops_enabled, scratch_bytes) =
- alg_scratch_and_tc.ConsumeValueOrDie();
-
- VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
- << NumBytesToString(scratch_bytes)
+ auto best_algo = std::move(best_algo_or).ValueOrDie();
+ VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm
+ << " and " << NumBytesToString(best_algo.scratch_bytes)
<< " of scratch memory: " << instr->ToString()
- << " tensor_ops_enabled: " << tensor_ops_enabled;
+ << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled;
// Replace instr with a new CustomCall which has the correct algorithm, and
// whose output shape has the appropriate amount of scratch memory.
HloComputation* computation = instr->parent();
- Shape new_call_shape =
- ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
- ShapeUtil::MakeShape(U8, {scratch_bytes})});
+ Shape new_call_shape = ShapeUtil::MakeTupleShape(
+ {instr->shape().tuple_shapes(0),
+ ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})});
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
instr->backend_config<CudnnConvBackendConfig>());
- backend_config.set_algorithm(algorithm);
- backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
+ backend_config.set_algorithm(best_algo.algorithm);
+ backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled);
HloInstruction* new_call = computation->AddInstruction(
instr->CloneWithNewOperands(new_call_shape, instr->operands()));
+ VLOG(1) << "Replacing convolution " << instr->ToString() << " with "
+ << new_call->ToString();
+
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index aeda2fc7f8..136c32210a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -47,10 +48,16 @@ class CudnnConvolutionAlgorithmPicker : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;
private:
+ struct AutotuneResult {
+ int64 algorithm;
+ bool tensor_ops_enabled;
+ int64 scratch_bytes;
+ absl::Duration runtime;
+ };
+
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
- StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- HloCustomCallInstruction* instr);
+ StatusOr<AutotuneResult> PickBestAlgorithm(HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index ef29237301..437d25727e 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -525,6 +525,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
TF_RETURN_IF_ERROR(
custom_call->set_backend_config(GetDefaultBackendConfig()));
+ VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
+ << custom_call->ToString();
+
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
// the conv result and replace `conv` with it.
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
index 3761c19cfc..d508cbc2e1 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
@@ -234,7 +234,8 @@ StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
config.set_side_input_scale(alpha_side_input);
TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
- VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name();
+ VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
+ << new_conv->ToString();
return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
new_conv, 0);
}
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index c1aaa4bf04..6dcdaf1cfe 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
const Window& window = hlo->window();
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
- return Unimplemented(
- "Dilation for reduce-window not implemented on GPU. "
- "See b/31410564.");
- }
-
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
@@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(
+ NSWAdd(stridden_index,
+ NSWMul(window_index[i],
+ index_typed_const(
+ window.dimensions(i).window_dilation()))),
+ index_typed_const(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation())),
+ index_typed_const(0));
+ in_bounds = And(in_bounds, dilation_condition);
+
+ // Apply base dilation to the index.
input_index[i] =
- NSWSub(NSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ SDiv(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 74352f26aa..8c9a8adc61 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -125,14 +124,8 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
DataLayout input;
FilterLayout filter;
DataLayout output;
- if (ConvUseLayoutHeuristic(instr->GetModule()->config())) {
- std::tie(input, filter, output) =
- HeuristicLayoutAssignment(instr, stream_executor_);
- } else {
- input = DataLayout::kBatchDepthYX;
- filter = FilterLayout::kOutputInputYX;
- output = DataLayout::kBatchDepthYX;
- }
+ std::tie(input, filter, output) =
+ HeuristicLayoutAssignment(instr, stream_executor_);
TF_ASSIGN_OR_RETURN(
std::tie(*input_shape->mutable_layout(),
@@ -220,16 +213,6 @@ Status GpuLayoutAssignment::AddBackendConstraints(
return Status::OK();
}
-bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout(
- const HloInstruction* instruction) {
- // - Inputs to cudnn batchnorm custom calls don't need the major-first layout
- // (i.e. {n, n-1, ...0}) -- we can handle any layout.
- // - Inputs to cudnn convolution require custom layouts handled in
- // AddBackendConstraints.
- return !IsCustomCallToDnnBatchNorm(*instruction) &&
- !IsCustomCallToDnnConvolution(*instruction);
-}
-
Status GpuLayoutAssignment::PropagateOperandConstraint(
const OperandLayoutConstraint& layout_constraint,
LayoutConstraints* constraints) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 4ba7989e9c..6a48e55fd2 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -46,8 +46,6 @@ class GpuLayoutAssignment : public LayoutAssignment {
Status PropagateBufferConstraint(
const BufferLayoutConstraint& buffer_constraint,
LayoutConstraints* constraints) override;
- bool CustomCallRequiresMajorFirstLayout(
- const HloInstruction* instruction) override;
private:
Status AddBackendConstraintsToDnnConvCustomCall(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h
deleted file mode 100644
index 498d4a9495..0000000000
--- a/tensorflow/compiler/xla/service/gpu/gpu_options.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
-
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-
-// Helper functions for querying options that are specific to the GPU backend.
-
-namespace xla {
-namespace gpu {
-
-// Returns true if we should use heuristics to assign convolution layouts, as
-// opposed to always assigning NCHW.
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config);
-
-} // namespace gpu
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 1ea26ddd5b..a0eb9e6ddc 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 56
+// Next ID: 58
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -184,6 +184,13 @@ message HloInstructionProto {
// Sharding for kDomain instructions.
xla.OpSharding domain_entry_sharding = 54;
xla.OpSharding domain_exit_sharding = 55;
+
+ // For custom call this indicates that the layouts are constrained. If
+ // constrain_layout is true then the 'shape' field must contain a layout, and
+ // 'operand_shapes_with_layout' must contain a shape with layout for each
+ // operand.
+ bool constrain_layout = 56;
+ repeated Shape operand_shapes_with_layout = 57;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 6ca1255ede..c6d02f9f67 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -42,18 +42,19 @@ namespace xla {
return std::move(domain_map);
}
-bool HloDomainMap::InSameDomain(HloInstruction* instruction1,
- HloInstruction* instruction2) const {
+bool HloDomainMap::InSameDomain(const HloInstruction* instruction1,
+ const HloInstruction* instruction2) const {
int64 domain_id1 = GetDomainId(instruction1);
int64 domain_id2 = GetDomainId(instruction2);
return domain_id1 >= 0 && domain_id1 == domain_id2;
}
-int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
+int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const {
return FindOrDefault(instruction_to_domain_, instruction, -1);
}
-int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
+int64 HloDomainMap::GetDomainMetadataId(
+ const HloInstruction* instruction) const {
return FindOrDie(domain_metadata_id_, instruction);
}
@@ -200,7 +201,8 @@ StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
return std::move(domain);
}
-bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
+bool HloDomainMap::IsDomainInstruction(
+ const HloInstruction* instruction) const {
if (instruction->opcode() != HloOpcode::kDomain) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index c8d581b746..bce7d1aa7c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -58,21 +58,21 @@ class HloDomainMap {
}
// Checks whether two instructions are within the same domain.
- bool InSameDomain(HloInstruction* instruction1,
- HloInstruction* instruction2) const;
+ bool InSameDomain(const HloInstruction* instruction1,
+ const HloInstruction* instruction2) const;
// Checks whether instruction is a kDomain instruction of the kind we are
// currently processing.
- bool IsDomainInstruction(HloInstruction* instruction) const;
+ bool IsDomainInstruction(const HloInstruction* instruction) const;
// Retrieves the domain identifier of the instruction, or -1 in case
// instruction is not found within any domain.
- int64 GetDomainId(HloInstruction* instruction) const;
+ int64 GetDomainId(const HloInstruction* instruction) const;
// Returns the unique id of the domain metadata for the domain the given
// instruction belongs to. The given instruction must not be a kDomain
// instruction since each domain instruction is associated with 2 domains.
- int64 GetDomainMetadataId(HloInstruction* instruction) const;
+ int64 GetDomainMetadataId(const HloInstruction* instruction) const;
private:
// Map used for representing instruction ordering, i.e.
@@ -119,8 +119,8 @@ class HloDomainMap {
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
- absl::flat_hash_map<HloInstruction*, int64> instruction_to_domain_;
- absl::flat_hash_map<HloInstruction*, int64> domain_metadata_id_;
+ absl::flat_hash_map<const HloInstruction*, int64> instruction_to_domain_;
+ absl::flat_hash_map<const HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index cee11a8a21..608a42bb60 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1463,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
+TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) {
+ HloComputation::Builder b(TestName());
+
+ // arg:
+ // f32[3,3] {
+ // { 1, 2, 3 },
+ // { 5, 6, 7 },
+ // { 9, 10, 11 },
+ // }
+ auto arg_array = absl::make_unique<Array2D<float>>(3, 3);
+ arg_array->FillUnique(1.0f);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
+
+ HloInstruction* arg_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
+
+ HloComputation::Builder max_computation("max");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ max_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
+ auto max_func = module().AddEmbeddedComputation(max_computation.Build());
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(2);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1});
+ b.AddInstruction(HloInstruction::CreateReduceWindow(
+ shape, arg_instruction, init_value, window, max_func));
+
+ module().AddEntryComputation(b.Build());
+
+ Literal result = Evaluate();
+
+ auto expected = LiteralUtil::CreateR2<float>({{11}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
HloComputation::Builder b(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index b2d12c94b8..a450dc6ff5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -2613,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> base_index(rank);
bool out_of_bound = false;
for (int64 i = 0; i < rank; ++i) {
- base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
+ base_index[i] =
+ window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] * window.dimensions(i).window_dilation() -
+ window.dimensions(i).padding_low();
+ // We are not in the base area if the dilation placed us out of bounds.
+ if (base_index[i] % window.dimensions(i).base_dilation() != 0) {
+ out_of_bound = true;
+ break;
+ }
+ // Apply the dilation to the base area.
+ base_index[i] /= window.dimensions(i).base_dilation();
if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
out_of_bound = true;
break;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 2f6db7cd7c..5c3908a9a4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -396,9 +396,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
operands(1), operands(2), computations(1));
break;
case HloOpcode::kCustomCall:
- instruction = CreateCustomCall(proto.shape(), all_operands(),
- proto.custom_call_target(),
- proto.custom_call_opaque());
+ if (proto.constrain_layout()) {
+ // A proto RepeatedPtrField cannot be converted to a Span (it is a
+ // vector of pointers essentially) so create a vector of shapes to pass
+ // in.
+ std::vector<Shape> operand_shapes;
+ for (const Shape& shape : proto.operand_shapes_with_layout()) {
+ operand_shapes.push_back(shape);
+ }
+ instruction = CreateCustomCall(
+ proto.shape(), all_operands(), proto.custom_call_target(),
+ operand_shapes, proto.custom_call_opaque());
+ } else {
+ instruction = CreateCustomCall(proto.shape(), all_operands(),
+ proto.custom_call_target(),
+ proto.custom_call_opaque());
+ }
if (proto.has_window()) {
static_cast<HloCustomCallInstruction*>(instruction.get())
->set_window(proto.window());
@@ -1142,6 +1155,15 @@ bool HloInstruction::HasSideEffect() const {
shape, operands, custom_call_target, opaque);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ absl::string_view opaque) {
+ return absl::make_unique<HloCustomCallInstruction>(
+ shape, operands, custom_call_target, opaque, operand_shapes_with_layout);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
absl::Span<HloInstruction* const> elements) {
std::vector<Shape> element_shapes;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 374862c4b6..44f776ebac 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -734,6 +734,16 @@ class HloInstruction {
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target, absl::string_view opaque = "");
+ // Overload which constrains the layouts of the operand and result. 'shape'
+ // and 'operand_shapes_with_layout' must have layouts.
+ // 'operand_shapes_with_layout' must have a compatible element for each
+ // operand.
+ static std::unique_ptr<HloInstruction> CreateCustomCall(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ absl::string_view opaque = "");
+
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 152d8eacdb..2ec233eaec 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1825,7 +1825,24 @@ HloCustomCallInstruction::HloCustomCallInstruction(
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
opaque_(opaque.begin(), opaque.end()),
- feature_group_count_(1) {
+ feature_group_count_(1),
+ layout_constrained_(false) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloCustomCallInstruction::HloCustomCallInstruction(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target, absl::string_view opaque,
+ absl::Span<const Shape> operand_shapes_with_layout)
+ : HloInstruction(HloOpcode::kCustomCall, shape),
+ custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
+ feature_group_count_(1),
+ layout_constrained_(true),
+ operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
+ operand_shapes_with_layout.end()) {
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1843,6 +1860,12 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
proto.set_custom_call_target(custom_call_target_);
proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
+ if (layout_constrained()) {
+ proto.set_constrain_layout(true);
+ for (const Shape& shape : operand_shapes_with_layout_) {
+ *proto.add_operand_shapes_with_layout() = shape;
+ }
+ }
return proto;
}
@@ -1870,6 +1893,14 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
if (!opaque_.empty()) {
extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
}
+ if (layout_constrained()) {
+ std::vector<string> shape_strings;
+ for (const Shape& shape : operand_shapes_with_layout_) {
+ shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
+ }
+ extra.push_back(StrCat("operand_layout_constraints={",
+ StrJoin(shape_strings, ", "), "}"));
+ }
return extra;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e169604072..4c5fc759a3 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1053,10 +1053,19 @@ class HloSelectAndScatterInstruction : public HloInstruction {
class HloCustomCallInstruction : public HloInstruction {
public:
- explicit HloCustomCallInstruction(const Shape& shape,
- absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target,
- absl::string_view opaque);
+ HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::string_view opaque);
+
+ // Constructor for a custom call with constrained layout. 'shape' and
+ // 'operands_with_layout' must all have layouts.
+ HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::string_view opaque,
+ absl::Span<const Shape> operand_shapes_with_layout);
+
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1085,6 +1094,16 @@ class HloCustomCallInstruction : public HloInstruction {
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
+ // Returns whether the result and operand layouts are constrained.
+ bool layout_constrained() const { return layout_constrained_; }
+
+ // Returns the shapes (with layout) of the operands. CHECKs if this custom
+ // call does not have constrained layouts.
+ const std::vector<Shape>& operand_shapes_with_layout() const {
+ CHECK(layout_constrained());
+ return operand_shapes_with_layout_;
+ }
+
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
@@ -1106,6 +1125,11 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
// The number of feature groups. This is used for grouped convolutions.
int64 feature_group_count_;
+ // Whether the result and operand layouts are constrained.
+ bool layout_constrained_;
+ // For layout-constrained custom calls, this vector holds the shape with
+ // layout for each operand.
+ std::vector<Shape> operand_shapes_with_layout_;
};
class HloPadInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index dd62988bcc..96f9ff6654 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -174,6 +174,7 @@ class HloParser {
kDistribution,
kDomain,
kPrecisionList,
+ kShapeList
};
struct AttrConfig {
@@ -240,6 +241,7 @@ class HloParser {
bool ParseSliceRanges(SliceRanges* result);
bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
+ bool ParseShapeList(std::vector<Shape>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -1341,6 +1343,7 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
+ optional<std::vector<Shape>> operand_layout_constraints;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
@@ -1349,12 +1352,52 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
+ attrs["operand_layout_constraints"] = {
+ /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateCustomCall(shape, operands, *custom_call_target,
- opaque.has_value() ? *opaque : ""));
+ if (operand_layout_constraints.has_value()) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return Error(lexer_.GetLoc(),
+ "Layout must be set on layout-constrained custom call");
+ }
+ if (operands.size() != operand_layout_constraints->size()) {
+ return Error(lexer_.GetLoc(),
+ StrCat("Expected ", operands.size(),
+ " operand layout constraints, ",
+ operand_layout_constraints->size(), " given"));
+ }
+ for (int64 i = 0; i < operands.size(); ++i) {
+ const Shape& operand_shape_with_layout =
+ (*operand_layout_constraints)[i];
+ if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
+ return Error(lexer_.GetLoc(),
+ StrCat("Operand layout constraint shape ",
+ ShapeUtil::HumanStringWithLayout(
+ operand_shape_with_layout),
+ " for operand ", i, " does not have a layout"));
+ }
+ if (!ShapeUtil::Compatible(operand_shape_with_layout,
+ operands[i]->shape())) {
+ return Error(
+ lexer_.GetLoc(),
+ StrCat(
+ "Operand layout constraint shape ",
+ ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
+ " for operand ", i,
+ " is not compatible with operand shape ",
+ ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
+ }
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+ shape, operands, *custom_call_target, *operand_layout_constraints,
+ opaque.has_value() ? *opaque : ""));
+ } else {
+ instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+ shape, operands, *custom_call_target,
+ opaque.has_value() ? *opaque : ""));
+ }
if (window.has_value()) {
instruction->set_window(*window);
}
@@ -2533,6 +2576,15 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kShapeList: {
+ std::vector<Shape> result;
+ if (!ParseShapeList(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
@@ -2825,6 +2877,23 @@ bool HloParser::ParsePrecisionList(
parse_and_add_item);
}
+// shapelist ::= '{' shapes '}'
+// precision_elements
+// ::= /*empty*/
+// ::= shape (',' shape)*
+bool HloParser::ParseShapeList(std::vector<Shape>* result) {
+ auto parse_and_add_item = [&]() {
+ Shape shape;
+ if (!ParseShape(&shape)) {
+ return false;
+ }
+ result->push_back(std::move(shape));
+ return true;
+ };
+ return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item);
+}
+
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
@@ -2832,23 +2901,15 @@ bool HloParser::ParsePrecisionList(
bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result) {
- if (!ParseToken(start, StrCat("expects an int64 list starting with ",
- TokKindToString(start)))) {
- return false;
- }
- if (lexer_.GetKind() == end) {
- // empty
- } else {
- do {
- tensorflow::int64 i;
- if (!ParseInt64(&i)) {
- return false;
- }
- result->push_back(i);
- } while (EatIfPresent(delim));
- }
- return ParseToken(
- end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
+ auto parse_and_add_item = [&]() {
+ tensorflow::int64 i;
+ if (!ParseInt64(&i)) {
+ return false;
+ }
+ result->push_back(i);
+ return true;
+ };
+ return ParseList(start, end, delim, parse_and_add_item);
}
bool HloParser::ParseList(const TokKind start, const TokKind end,
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 255123d331..17538c05bc 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -804,6 +804,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
)"
},
+// CustomCallWithLayoutConstraints
+{
+"CustomCallWithLayoutConstraints",
+R"(HloModule CustomCallWithLayoutConstraints
+
+ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
+}
+
+)"
+},
+// CustomCallWithLayoutConstraintsNoOperands
+{
+"CustomCallWithLayoutConstraintsNoOperands",
+R"(HloModule CustomCallWithLayoutConstraintsNoOperands
+
+ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
+}
+
+)"
+},
+// CustomCallWithLayoutConstraintsTupleShapes
+{
+"CustomCallWithLayoutConstraintsTupleShapes",
+R"(HloModule CustomCallWithLayoutConstraintsTupleShapes
+
+ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
+ %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
+}
+
+)"
+},
});
// clang-format on
}
@@ -2069,5 +2106,35 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
+ const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints
+
+ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
+}
+
+)";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "Expected 2 operand layout constraints, 1 given");
+}
+
+TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
+ const string original = R"(HloModule CustomCallIncompatibleOperandConstraints
+
+ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
+}
+
+)";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "operand 1 is not compatible with operand shape");
+}
+
+// custom call incompatible shape.
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index fad3b14ec2..be3bee5975 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -313,8 +313,9 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
operand_dimension < ShapeUtil::Rank(operand_shape);
++operand_dimension) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
- TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(operand_dimension))
+ TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) &&
+ (broadcast->shape().dimensions(output_dimension) ==
+ operand_shape.dimensions(operand_dimension)))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
return Status::OK();
@@ -359,7 +360,27 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) {
return CheckShape(call, call->to_apply()->root_instruction()->shape());
}
-Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<const HloCustomCallInstruction>(instruction);
+ TF_RET_CHECK(custom_call != nullptr);
+ if (custom_call->layout_constrained()) {
+ // If the layout is constrained, verify all the respective shapes have
+ // layouts and that the constrained operand shapes match the shapes of the
+ // operands.
+ TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
+ TF_RET_CHECK(custom_call->operand_count() ==
+ custom_call->operand_shapes_with_layout().size());
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ const Shape& operand_shape_with_layout =
+ custom_call->operand_shapes_with_layout()[i];
+ TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
+ operand_shape_with_layout));
+ TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
+ }
+ }
+ return Status::OK();
+}
Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
return CheckShape(slice,
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index cc4a342e9d..ad65b147c1 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -419,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints(
return Status::OK();
}
+namespace {
+
+bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
+ return custom_call != nullptr && custom_call->layout_constrained();
+}
+
+} // namespace
+
Status LayoutAssignment::AddMandatoryConstraints(
const ComputationLayout* computation_layout,
ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
@@ -434,7 +444,6 @@ Status LayoutAssignment::AddMandatoryConstraints(
// Constrain layouts of instructions which define values with pre-existing
// layouts.
for (auto* instruction : computation->instructions()) {
- Shape const* shape_with_layout = nullptr;
if (instruction->opcode() == HloOpcode::kInfeed) {
// Infeed layouts must match the layout of the original inserted
// instruction.
@@ -456,17 +465,21 @@ Status LayoutAssignment::AddMandatoryConstraints(
if (parameter_layout.LayoutIsSet()) {
// Parameter layouts must match the respective layout in
// ComputationLayout, if there is one.
- shape_with_layout = &parameter_layout.shape();
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+ parameter_layout.shape(), instruction));
}
}
- }
- if (shape_with_layout != nullptr) {
+ } else if (IsLayoutConstrainedCustomCall(instruction)) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
TF_RETURN_IF_ERROR(
- constraints->SetInstructionLayout(*shape_with_layout, instruction));
- }
-
- if (instruction->opcode() == HloOpcode::kSend ||
- instruction->opcode() == HloOpcode::kRecv) {
+ constraints->SetInstructionLayout(custom_call->shape(), custom_call));
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ custom_call->operand_shapes_with_layout()[i], custom_call, i));
+ }
+ } else if (instruction->opcode() == HloOpcode::kSend ||
+ instruction->opcode() == HloOpcode::kRecv) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id();
@@ -621,31 +634,6 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
false_computation_layout.parameter_shape(0), instruction, 2,
/*mandatory=*/true));
- } else if (instruction->opcode() == HloOpcode::kCustomCall) {
- if (!CustomCallRequiresMajorFirstLayout(instruction)) {
- continue;
- }
- // Add constraints for kCustomCall instruction operands and instructions.
- // For now we only support major-first layouts for all inputs and outputs.
- Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
- instruction->shape().element_type(),
- AsInt64Slice(instruction->shape().dimensions()));
- TF_RETURN_IF_ERROR(
- constraints->SetInstructionLayout(result_shape, instruction));
- for (int64 i = 0; i < instruction->operand_count(); ++i) {
- const Shape& operand_shape = instruction->operand(i)->shape();
- // Opaque operands don't get a layout constraint.
- if (ShapeUtil::IsOpaque(operand_shape)) {
- continue;
- }
-
- Shape row_major_operand_shape =
- ShapeUtil::MakeShapeWithDescendingLayout(
- operand_shape.element_type(),
- AsInt64Slice(operand_shape.dimensions()));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- row_major_operand_shape, instruction, i));
- }
}
}
// Finally set the result layout to match ComputationLayout, if there is one.
@@ -676,16 +664,18 @@ Status CheckCallLayout(HloInstruction* call,
return Status::OK();
}
-// Custom calls have fixed input and output layouts.
-Status CheckCustomCallLayout(HloInstruction* custom_call) {
- for (const HloInstruction* operand : custom_call->operands()) {
- TF_RET_CHECK(
- ShapeUtil::IsOpaque(operand->shape()) ||
- LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
+// Operands of layout-constrained custom calls must match the expected
+// constrained layouts.
+Status CheckCustomCallLayout(HloInstruction* instruction) {
+ if (IsLayoutConstrainedCustomCall(instruction)) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+ custom_call->operand(i)->shape(),
+ custom_call->operand_shapes_with_layout()[i]));
+ }
}
- TF_RET_CHECK(
- ShapeUtil::IsOpaque(custom_call->shape()) ||
- LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout()));
return Status::OK();
}
@@ -932,9 +922,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, instruction->to_apply())));
break;
case HloOpcode::kCustomCall:
- if (CustomCallRequiresMajorFirstLayout(instruction)) {
- TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
- }
+ TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
break;
case HloOpcode::kFusion:
TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
@@ -1554,11 +1542,11 @@ Status LayoutAssignment::CalculateComputationLayout(
Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
// Clear existing layouts of the instructions. All layouts must be assigned
- // by the LayoutAssignment pass, except for those on infeeds, parameters,
- // and the computation result. The latter two are specified in
- // computation_layout, so we only need to keep the existing layouts for
- // infeeds. Clearing the layouts here avoids hiding potential bugs in the
- // layout assignment pass that may accidentally use the existing layout.
+ // by the LayoutAssignment pass, except for those on parameters, the
+ // computation result, and a couple special cases. The former two are
+ // specified in computation_layout. Clearing the layouts here avoids hiding
+ // potential bugs in the layout assignment pass that may accidentally use the
+ // existing layout.
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kBitcast) {
// bitcasts are inherently layout sensitive and so a bitcast instruction
@@ -1567,7 +1555,9 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
"Unexpected bitcast operation seen during layout assignment: %s.",
instruction->ToString());
}
- if (instruction->opcode() != HloOpcode::kInfeed) {
+ // Some instructions carry mandatory layouts in their shape.
+ if (instruction->opcode() != HloOpcode::kInfeed &&
+ !IsLayoutConstrainedCustomCall(instruction)) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}
@@ -1802,6 +1792,18 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
}
TF_RETURN_IF_ERROR(Init());
+ // Verify computation layout is sane.
+ const HloComputation* entry = module->entry_computation();
+ TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
+ entry->num_parameters());
+ for (int64 i = 0; i < entry->num_parameters(); ++i) {
+ TF_RET_CHECK(
+ ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
+ entry->parameter_instruction(i)->shape()));
+ }
+ TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
+ entry->root_instruction()->shape()));
+
// We do two passes. The first one we pass a nullptr ComputationLayout to
// the RunOnComputation() calls (for non entry computations), and we register
// the ComputationLayout which are naturally flowing in DFS fashion to the
@@ -1873,7 +1875,6 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
- case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
@@ -1930,6 +1931,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kCopy:
+ case HloOpcode::kCustomCall:
case HloOpcode::kDomain:
case HloOpcode::kDot:
case HloOpcode::kFusion:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 2d48e12263..cb56f4cd19 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -333,19 +333,6 @@ class LayoutAssignment : public HloModulePass {
const ResultLayoutConstraint& layout_constraint,
LayoutConstraints* constraints);
- // By default LayoutAssignment ensures that inputs and outputs of CustomCalls
- // have the "major-first" layout (i.e. {n, n-1, ..., 0}).
- //
- // If this function returns true, LayoutAssignment does not set a layout for
- // the given CustomCall. It's up to the backend to set one in
- // AddBackendConstraints, if necessary.
- //
- // Precondition: instruction->opcode() == HloOpcode::kCustomCall.
- virtual bool CustomCallRequiresMajorFirstLayout(
- const HloInstruction* /*instruction*/) {
- return true;
- }
-
// Called after layouts of an instruction have been finalized to allow
// subclasses to check for platform specific assumptions.
virtual Status Verify(const HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 2c549cd872..ff6fdb5e4a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -65,6 +65,27 @@ class LayoutAssignmentTest : public HloVerifiedTestBase {
FindInstruction(module, name)->shape().layout().minor_to_major();
return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
}
+
+ void ExpectLayoutIs(const Shape& shape,
+ absl::Span<const int64> minor_to_major) {
+ const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+ EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
+ << "Expected layout " << expected << ", actual " << shape.layout();
+ }
+
+ void ExpectTupleLayoutIs(
+ const Shape& shape,
+ std::initializer_list<absl::Span<const int64>> minor_to_majors) {
+ int i = 0;
+ for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
+ const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+ const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
+ EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
+ << "Expected tuple element " << i << " layout " << expected
+ << ", actual " << actual;
+ ++i;
+ }
+ }
};
TEST_F(LayoutAssignmentTest, ComputationLayout) {
@@ -1102,5 +1123,174 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0));
}
+TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
+ const char* module_str = R"(
+HloModule CustomCallNotLayoutConstrained
+
+ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
+ %p = f32[42,2,3] parameter(0)
+ ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
+}
+)";
+ // Try with a couple different layouts. In each case the custom calls operand
+ // and result layout should match that of the computation.
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, op::CustomCall(op::Parameter()));
+ ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
+ ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
+ }
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, op::CustomCall(op::Parameter()));
+ ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
+ ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
+ }
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrained
+
+ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+ %p0 = f32[4,4] parameter(0)
+ %p1 = f32[2,3] parameter(1)
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ // The custom call should be partially encapsulated in kCopy instructions
+ // because of the layout mismatches.
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall(op::Copy(), op::Parameter())));
+
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+ ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
+ ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedZeroOperands
+
+ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall()));
+
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleOperand
+
+ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+ %p0 = f32[4,4] parameter(0)
+ %p1 = f32[2,3] parameter(1)
+ %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
+
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall(op::Tuple())));
+
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+ ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleResult
+
+ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
+ %p0 = f32[4,4] parameter(0)
+ ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
+}
+)";
+ // Try with a couple different layouts. In each case the custom calls operand
+ // and result layout should match that of the computation.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ ExpectTupleLayoutIs(module->entry_computation()->root_instruction()->shape(),
+ {{1, 0}, {1, 0}});
+
+ const HloInstruction* custom_call =
+ FindInstruction(module.get(), "custom-call");
+ ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index d244923532..7f0201942b 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1645,7 +1645,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
}
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
- out << ShapeUtil::HumanString(shape);
+ out << ShapeUtil::HumanStringWithLayout(shape);
return out;
}
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 070b092d18..b851db14ec 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
XlaBuilder builder(TestName());
auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
- Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ PrecisionConfig precision;
+ // The left hand side of the convolution is numbers between 0 and 2304 which
+ // requires at least 11 mantissa bits and the DEFAULT precision config is
+ // allowed to round to bfloat16 which only has 7 mantissa bits.
+ precision.add_operand_precision(PrecisionConfig::HIGHEST);
+ precision.add_operand_precision(PrecisionConfig::DEFAULT);
+ Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
+ &precision);
ComputeAndCompare(&builder, {}, error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index a693fa3595..001490c6a8 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -105,8 +105,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
-XLA_TEST_F(CustomCallTest,
- DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) {
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
auto module = CreateNewModule();
auto b = HloComputation::Builder(TestName());
@@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest,
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
+ auto module = CreateNewModule();
+ auto b = HloComputation::Builder(TestName());
+
+ auto input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+ b.AddInstruction(
+ HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues"));
+
+ module->AddEntryComputation(b.Build());
+ ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+ Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+
+ // Note, the expected result is transposed! This is because the input and
+ // output layouts of the custom call differ and the called function just
+ // blindly adds one to each element.
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+ LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
+}
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
+ // The argument and result of the computation are set to different layouts,
+ // but the custom call is layout constrained to a fixed operand and result
+ // layout, so the correct result should be produced.
+ auto module = CreateNewModule();
+ auto b = HloComputation::Builder(TestName());
+
+ auto input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+
+ const Shape& r2f32_dim0_major =
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
+ b.AddInstruction(HloInstruction::CreateCustomCall(
+ r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major}));
+
+ module->AddEntryComputation(b.Build());
+ ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+ Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+ LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result);
+}
+
class CustomCallClientAPITest : public ClientLibraryTestBase {};
// When using the client API, CustomCall targets can't begin with '$' -- these
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index c25ccafaf8..22fe4a2670 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
/*padding=*/padding);
CHECK(reducer == kAdd || reducer == kMax);
@@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
@@ -1369,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index f675c135f4..244683765a 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -1,6 +1,16 @@
# Minimum CMake required
cmake_minimum_required(VERSION 3.5)
+if(WIN32)
+ if(${CMAKE_VERSION} VERSION_LESS "3.8")
+ message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake.")
+ else()
+ if(NOT CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE OR NOT "${CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE}" STREQUAL "x64")
+ message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake.")
+ endif()
+ endif()
+endif()
+
# Project
project(tensorflow C CXX)
@@ -352,9 +362,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS})
- else (tensorflow_ENABLE_MKLDNN_SUPPORT)
- add_definitions(-DINTEL_MKL_ML_ONLY)
- endif()
+ endif(tensorflow_ENABLE_MKLDNN_SUPPORT)
endif (tensorflow_ENABLE_MKL_SUPPORT)
if (tensorflow_ENABLE_GPU)
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 77242b34fd..84c679162c 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -108,180 +108,177 @@ ops or APIs.
Step-by-step Windows build
==========================
-1. Install the prerequisites detailed above, and set up your environment.
-
- * The following commands assume that you are using the Windows Command
- Prompt (`cmd.exe`). You will need to set up your environment to use the
- appropriate toolchain, i.e. the 64-bit tools. (Some of the binary targets
- we will build are too large for the 32-bit tools, and they will fail with
- out-of-memory errors.) The typical command to do set up your
- environment is:
-
- ```
- D:\temp> "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat"
- ```
-
- * When building with GPU support after installing the CUDNN zip file from NVidia, append its
- bin directory to your PATH environment variable.
- In case TensorFlow fails to find the CUDA dll's during initialization, check your PATH environment variable.
- It should contain the directory of the CUDA dlls and the directory of the CUDNN dll.
- For example:
-
- ```
- D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin
- D:\local\cuda\bin
- ```
-
- * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin directories to your PATH environment variable.
-
- In case TensorFlow fails to find the MKL dll's during initialization, check your PATH environment variable.
- It should contain the directory of the MKL dlls. For example:
-
- ```
- D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl
- D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler
- D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt
- ```
-
-
- * We assume that `cmake` and `git` are installed and in your `%PATH%`. If
- for example `cmake` is not in your path and it is installed in
- `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory
- to your `%PATH%` as follows:
-
- ```
- D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe"
- ```
-
-2. Clone the TensorFlow repository and create a working directory for your
- build:
-
- ```
- D:\temp> git clone https://github.com/tensorflow/tensorflow.git
- D:\temp> cd tensorflow\tensorflow\contrib\cmake
- D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build
- D:\temp\tensorflow\tensorflow\contrib\cmake> cd build
- D:\temp\tensorflow\tensorflow\contrib\cmake\build>
- ```
-
-3. Invoke CMake to create Visual Studio solution and project files.
-
- **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment
- variable. The other paths are for illustrative purposes only, and may
- be different on your platform. The `^` character is a line continuation
- and must be the last character on each line.
-
- ```
- D:\...\build> cmake .. -A x64 -DCMAKE_BUILD_TYPE=Release ^
- More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^
- More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^
- More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib
- ```
- To build with GPU support add "^" at the end of the last line above following with:
- ```
- More? -Dtensorflow_ENABLE_GPU=ON ^
- More? -DCUDNN_HOME="D:\...\cudnn"
- ```
- To build with MKL support add "^" at the end of the last line above following with:
-
- ```
- More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^
- More? -DMKL_HOME="D:\...\compilers_and_libraries"
- ```
-
- To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows:
-
- ```
- More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX
- ```
-
- Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build
- configuration that you choose when invoking `msbuild`. The known-good
- values are `Release` and `RelWithDebInfo`. The `Debug` build type is
- not currently supported, because it relies on a `Debug` library for
- Python (`python35d.lib`) that is not distributed by default.
-
- There are various options that can be specified when generating the
- solution and project files:
-
- * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the
- `CMAKE_BUILD_TYPE` option must match the build configuration that you
- choose when invoking MSBuild in step 4. The known-good values are
- `Release` and `RelWithDebInfo`. The `Debug` build type is not currently
- supported, because it relies on a `Debug` library for Python
- (`python35d.lib`) that is not distributed by default.
-
- * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can
- build a small subset of the kernels for a faster build by setting this
- option to `OFF`.
-
- * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate
- project files for a simple C++
- [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc).
-
- * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. Generate
- project files for building a PIP package containing the TensorFlow runtime
- and its Python bindings.
-
- * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include
- gRPC support and the distributed client and server code in the TensorFlow
- runtime.
-
- * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include
- SSL support (for making secure HTTP requests) in the TensorFlow runtime.
- This support is incomplete, and will be used for Google Cloud Storage
- support.
-
- * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include
- GPU support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and CUDNN 5.1.
- CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unzipped_cudnn.
-
- * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds cc unit tests.
- There are many of them and building will take a few hours.
- After cmake, build and execute the tests with
- ```
- MSBuild /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj
- ctest -C RelWithDebInfo
- ```
-
- * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python kernel tests.
- After building the python wheel, you need to install the new wheel before running the tests.
- To execute the tests, use
- ```
- ctest -C RelWithDebInfo
- ```
-
- * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on
- serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`.
- After building the python wheel, you need to install the new wheel before running the tests.
- To execute the tests, use
- ```
- ctest -C RelWithDebInfo
- ```
-
- * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled you need to install the [Intel Math Kernal Library](https://software.intel.com/en-us/mkl).
- CMake will expect the location of MKL in -MKL_HOME=path_you_install_mkl.
-
- * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support.
-
-
-4. Invoke MSBuild to build TensorFlow.
-
- To build the C++ example program, which will be created as a `.exe`
- executable in the subdirectory `.\Release`:
-
- ```
- D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj
- D:\...\build> Release\tf_tutorials_example_trainer.exe
- ```
-
- To build the PIP package, which will be created as a `.whl` file in the
- subdirectory `.\tf_python\dist`:
-
- ```
- D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj
- ```
-
+1. Install the prerequisites detailed above, and set up your environment.
+
+ * When building with GPU support after installing the CUDNN zip file from
+ NVidia, append its bin directory to your PATH environment variable. In
+ case TensorFlow fails to find the CUDA dll's during initialization,
+ check your PATH environment variable. It should contain the directory of
+ the CUDA dlls and the directory of the CUDNN dll. For example:
+
+ ```
+ D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin
+ D:\local\cuda\bin
+ ```
+
+ * When building with MKL support after installing
+ [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin
+ directories to your PATH environment variable.
+
+ In case TensorFlow fails to find the MKL dll's during initialization,
+ check your PATH environment variable. It should contain the directory of
+ the MKL dlls. For example:
+
+ ```
+ D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl
+ D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler
+ D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt
+ ```
+
+ * We assume that `cmake` and `git` are installed and in your `%PATH%`. If
+ for example `cmake` is not in your path and it is installed in
+ `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory
+ to your `%PATH%` as follows:
+
+ ```
+ D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe"
+ ```
+
+2. Clone the TensorFlow repository and create a working directory for your
+ build:
+
+ ```
+ D:\temp> git clone https://github.com/tensorflow/tensorflow.git
+ D:\temp> cd tensorflow\tensorflow\contrib\cmake
+ D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build
+ D:\temp\tensorflow\tensorflow\contrib\cmake> cd build
+ D:\temp\tensorflow\tensorflow\contrib\cmake\build>
+ ```
+
+3. Invoke CMake to create Visual Studio solution and project files.
+
+ **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment
+ variable. The other paths are for illustrative purposes only, and may be
+ different on your platform. The `^` character is a line continuation and
+ must be the last character on each line.
+
+ ```
+ D:\...\build> cmake .. -A x64 -Thost=x64 -DCMAKE_BUILD_TYPE=Release ^
+ More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^
+ More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^
+ More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib
+ ```
+
+ To build with GPU support add "^" at the end of the last line above
+ following with: `More? -Dtensorflow_ENABLE_GPU=ON ^ More?
+ -DCUDNN_HOME="D:\...\cudnn"` To build with MKL support add "^" at the end of
+ the last line above following with:
+
+ ```
+ More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^
+ More? -DMKL_HOME="D:\...\compilers_and_libraries"
+ ```
+
+ To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows:
+
+ ```
+ More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX
+ ```
+
+ Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build
+ configuration that you choose when invoking `msbuild`. The known-good values
+ are `Release` and `RelWithDebInfo`. The `Debug` build type is not currently
+ supported, because it relies on a `Debug` library for Python
+ (`python35d.lib`) that is not distributed by default.
+
+ The `-Thost=x64` flag will ensure that the 64 bit compiler and linker is
+ used when building. Without this flag, MSBuild will use the 32 bit toolchain
+ which is prone to compile errors such as "compiler out of heap space".
+
+ There are various options that can be specified when generating the solution
+ and project files:
+
+ * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the
+ `CMAKE_BUILD_TYPE` option must match the build configuration that you
+ choose when invoking MSBuild in step 4. The known-good values are
+ `Release` and `RelWithDebInfo`. The `Debug` build type is not currently
+ supported, because it relies on a `Debug` library for Python
+ (`python35d.lib`) that is not distributed by default.
+
+ * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can
+ build a small subset of the kernels for a faster build by setting this
+ option to `OFF`.
+
+ * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate
+ project files for a simple C++
+ [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc).
+
+ * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`.
+ Generate project files for building a PIP package containing the
+ TensorFlow runtime and its Python bindings.
+
+ * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include
+ gRPC support and the distributed client and server code in the
+ TensorFlow runtime.
+
+ * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include
+ SSL support (for making secure HTTP requests) in the TensorFlow runtime.
+ This support is incomplete, and will be used for Google Cloud Storage
+ support.
+
+ * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include GPU
+ support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and
+ CUDNN 5.1. CMake will expect the location of CUDNN in
+ -DCUDNN_HOME=path_you_unzipped_cudnn.
+
+ * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds
+ cc unit tests. There are many of them and building will take a few
+ hours. After cmake, build and execute the tests with `MSBuild
+ /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj ctest -C
+ RelWithDebInfo`
+
+ * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This
+ enables python kernel tests. After building the python wheel, you need
+ to install the new wheel before running the tests. To execute the tests,
+ use `ctest -C RelWithDebInfo`
+
+ * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This
+ enables python tests on serveral major packages. This option is only
+ valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`.
+ After building the python wheel, you need to install the new wheel
+ before running the tests. To execute the tests, use `ctest -C
+ RelWithDebInfo`
+
+ * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include
+ MKL support. If MKL is enabled you need to install the
+ [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). CMake
+ will expect the location of MKL in -MKL_HOME=path_you_install_mkl.
+
+ * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`.
+ Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for
+ Deep Neural Networks (Intel(R)
+ MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add
+ `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support.
+
+4. Invoke MSBuild to build TensorFlow.
+
+ Set up the path to find MSbuild: `D:\temp> "C:\Program Files (x86)\Microsoft
+ Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat"`
+
+ To build the C++ example program, which will be created as a `.exe`
+ executable in the subdirectory `.\Release`:
+
+ ```
+ D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj
+ D:\...\build> Release\tf_tutorials_example_trainer.exe
+ ```
+
+ To build the PIP package, which will be created as a `.whl` file in the
+ subdirectory `.\tf_python\dist`:
+
+ ```
+ D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj
+ ```
Linux Continuous Integration build
==================================
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 8267612236..76d5b59ce1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -412,6 +412,24 @@ cuda_py_test(
)
cuda_py_test(
+ name = "moving_averages_test",
+ srcs = ["moving_averages_test.py"],
+ additional_deps = [
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
+
+cuda_py_test(
name = "optimizer_v2_test",
srcs = ["optimizer_v2_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index cff4b0a463..63a163e76c 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -349,26 +349,26 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
required_gpus=2)
-adam_optimizer_v1_fn = NamedObject(
- "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+adam_optimizer_v1_fn = NamedObject("AdamV1",
+ lambda: adam.AdamOptimizer(0.001, epsilon=1))
rmsprop_optimizer_v1_fn = NamedObject(
"RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
-optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
- adagrad_optimizer_v1_fn]
-adam_optimizer_v2_fn = NamedObject(
- "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
+
gradient_descent_optimizer_v2_fn = NamedObject(
"GradientDescentV2",
lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
adagrad_optimizer_v2_fn = NamedObject(
"AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001))
-optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn,
- adagrad_optimizer_v2_fn]
+adam_optimizer_v2_fn = NamedObject(
+ "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+
+optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn]
graph_and_eager_modes = ["graph", "eager"]
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index a84ef04196..da7f8c548f 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -113,7 +113,7 @@ def main(_):
distribute=strategy)
# Train the model with the train dataset.
- model.fit(x=train_ds, epochs=20, steps_per_epoch=310)
+ model.fit(x=train_ds, epochs=20, steps_per_epoch=468)
# Evaluate the model with the eval dataset.
score = model.evaluate(eval_ds, steps=10, verbose=0)
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index ba147e7824..60e134055f 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -179,11 +179,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def get_expected_variables(optimizer_fn, num_parameter_devices):
variables_map = {
"GradientDescent": ["dense/kernel", "dense/bias"],
- "Adam": [
- "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
- "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
- "dense/bias/Adam_1"
- ],
"Adagrad": [
"dense/kernel/Adagrad", "dense/kernel",
"dense/bias/Adagrad", "dense/bias"
diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py
new file mode 100644
index 0000000000..119352ad91
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/moving_averages_test.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training.moving_averages when using a DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import moving_averages
+
+
+all_combinations = combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu],
+ mode=["graph"])
+
+
+class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_combinations)
+ def testTowerModeWithoutZeroDebias(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+ return var, assign
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(distribution.unwrap(assign))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testTowerMode(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([0.0, 0.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ return var, assign.op
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign_op = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(distribution.unwrap(assign_op))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ self.assertAllClose(average_val, var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTowerWithoutZeroDebias(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0, 2.0])
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(assign)
+ average_val = [1.0, 2.0]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+ # Also try assign.op.
+ sess.run(assign.op)
+ orig_weight = 0.25 * 0.25
+ val_weight = 1.0 - orig_weight
+ self.assertAllClose(
+ [10.0 * orig_weight + average_val[0] * val_weight,
+ 11.0 * orig_weight + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTower(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([0.0, 0.0])
+ val = array_ops.placeholder(dtypes.float32)
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(var, val, decay)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(assign, feed_dict={val: [1.0, 2.0]})
+ self.assertAllClose([1.0, 2.0], var.eval())
+
+ # Also try assign.op.
+ sess.run(assign.op, feed_dict={val: [10.0, 0.0]})
+ self.assertAllClose(
+ [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0),
+ (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
+ var.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 353d11a583..9c112e4f85 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -262,7 +262,9 @@ class ParameterServerStrategyTestBase(
h = f + 1.0
self.assertEqual(
device_util.canonicalize(u.device), tower_variable_device)
- self.assertEqual(device_util.canonicalize(x.device), h.device)
+ self.assertEqual(
+ device_util.canonicalize(x.device),
+ device_util.canonicalize(h.device))
return y_add, z_add, f
y, z, f = d.call_for_each_tower(model_fn)
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index a1f1c5f3d7..b131ed4f12 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -75,7 +75,7 @@ class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint:
layer.
head: the `Head` instance defined for Estimator.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -199,7 +199,7 @@ def boosted_trees_classifier_train_in_memory(
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
n_classes: number of label classes. Default is binary classification.
Multiclass support is not yet implemented.
@@ -345,7 +345,7 @@ def boosted_trees_regressor_train_in_memory(
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
label_dimension: Number of regression targets per example.
Multi-dimensional support is not yet implemented.
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
index 724bc2c82f..4e7965ef26 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
@@ -118,7 +118,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
head: A `_Head` instance constructed with a method such as
`tf.contrib.estimator.multi_label_head`.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
linear_feature_columns: An iterable containing all the feature columns
used by linear part of the model. All items in the set must be
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 6ca7aaf989..40a91175b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -248,7 +248,7 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
model. All items in the set should be instances of classes derived from
`_FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can also
- be used to load checkpoints from the directory into a estimator to
+ be used to load checkpoints from the directory into an estimator to
continue training a previously saved model.
n_classes: Number of label classes. Defaults to 2, namely binary
classification. Must be > 1.
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 98660bb731..c595f47395 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
@@ -92,55 +91,6 @@ def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'):
return rnn_cell_fn
-def _concatenate_context_input(sequence_input, context_input):
- """Replicates `context_input` across all timesteps of `sequence_input`.
-
- Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
- This value is appended to `sequence_input` on dimension 2 and the result is
- returned.
-
- Args:
- sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
- padded_length, d0]`.
- context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
-
- Returns:
- A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
- d0 + d1]`.
-
- Raises:
- ValueError: If `sequence_input` does not have rank 3 or `context_input` does
- not have rank 2.
- """
- seq_rank_check = check_ops.assert_rank(
- sequence_input,
- 3,
- message='sequence_input must have rank 3',
- data=[array_ops.shape(sequence_input)])
- seq_type_check = check_ops.assert_type(
- sequence_input,
- dtypes.float32,
- message='sequence_input must have dtype float32; got {}.'.format(
- sequence_input.dtype))
- ctx_rank_check = check_ops.assert_rank(
- context_input,
- 2,
- message='context_input must have rank 2',
- data=[array_ops.shape(context_input)])
- ctx_type_check = check_ops.assert_type(
- context_input,
- dtypes.float32,
- message='context_input must have dtype float32; got {}.'.format(
- context_input.dtype))
- with ops.control_dependencies(
- [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
- padded_length = array_ops.shape(sequence_input)[1]
- tiled_context_input = array_ops.tile(
- array_ops.expand_dims(context_input, 1),
- array_ops.concat([[1], [padded_length], [1]], 0))
- return array_ops.concat([sequence_input, tiled_context_input], 2)
-
-
def _select_last_activations(activations, sequence_lengths):
"""Selects the nth set of activations for each n in `sequence_length`.
@@ -222,8 +172,8 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns,
context_input = feature_column_lib.input_layer(
features=features,
feature_columns=context_feature_columns)
- sequence_input = _concatenate_context_input(sequence_input,
- context_input)
+ sequence_input = seq_fc.concatenate_context_input(
+ context_input, sequence_input)
cell = rnn_cell_fn(mode)
# Ignore output state.
diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD
index aab7d0c9e8..a926ffd598 100644
--- a/tensorflow/contrib/feature_column/BUILD
+++ b/tensorflow/contrib/feature_column/BUILD
@@ -27,6 +27,7 @@ py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:tensor_shape",
@@ -46,9 +47,29 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
"//tensorflow/python/feature_column",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "sequence_feature_column_integration_test",
+ srcs = ["python/feature_column/sequence_feature_column_integration_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":sequence_feature_column",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/keras:layers",
],
)
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index 05bcdac2ca..dd6da35ed0 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -33,7 +33,6 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
# pylint: disable=protected-access
-# TODO(b/73827486): Support SequenceExample.
def sequence_input_layer(
@@ -110,6 +109,7 @@ def sequence_input_layer(
output_tensors = []
sequence_lengths = []
ordered_columns = []
+
for column in sorted(feature_columns, key=lambda x: x.name):
ordered_columns.append(column)
with variable_scope.variable_scope(
@@ -121,17 +121,67 @@ def sequence_input_layer(
# Flattens the final dimension to produce a 3D Tensor.
num_elements = column._variable_shape.num_elements()
shape = array_ops.shape(dense_tensor)
+ target_shape = [shape[0], shape[1], num_elements]
output_tensors.append(
- array_ops.reshape(
- dense_tensor,
- shape=array_ops.concat([shape[:2], [num_elements]], axis=0)))
+ array_ops.reshape(dense_tensor, shape=target_shape))
sequence_lengths.append(sequence_length)
+
fc._verify_static_batch_size_equality(output_tensors, ordered_columns)
fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns)
sequence_length = _assert_all_equal_and_return(sequence_lengths)
+
return array_ops.concat(output_tensors, -1), sequence_length
+def concatenate_context_input(context_input, sequence_input):
+ """Replicates `context_input` across all timesteps of `sequence_input`.
+
+ Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
+ This value is appended to `sequence_input` on dimension 2 and the result is
+ returned.
+
+ Args:
+ context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
+ sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
+ padded_length, d0]`.
+
+ Returns:
+ A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
+ d0 + d1]`.
+
+ Raises:
+ ValueError: If `sequence_input` does not have rank 3 or `context_input` does
+ not have rank 2.
+ """
+ seq_rank_check = check_ops.assert_rank(
+ sequence_input,
+ 3,
+ message='sequence_input must have rank 3',
+ data=[array_ops.shape(sequence_input)])
+ seq_type_check = check_ops.assert_type(
+ sequence_input,
+ dtypes.float32,
+ message='sequence_input must have dtype float32; got {}.'.format(
+ sequence_input.dtype))
+ ctx_rank_check = check_ops.assert_rank(
+ context_input,
+ 2,
+ message='context_input must have rank 2',
+ data=[array_ops.shape(context_input)])
+ ctx_type_check = check_ops.assert_type(
+ context_input,
+ dtypes.float32,
+ message='context_input must have dtype float32; got {}.'.format(
+ context_input.dtype))
+ with ops.control_dependencies(
+ [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
+ padded_length = array_ops.shape(sequence_input)[1]
+ tiled_context_input = array_ops.tile(
+ array_ops.expand_dims(context_input, 1),
+ array_ops.concat([[1], [padded_length], [1]], 0))
+ return array_ops.concat([sequence_input, tiled_context_input], 2)
+
+
def sequence_categorical_column_with_identity(
key, num_buckets, default_value=None):
"""Returns a feature column that represents sequences of integers.
@@ -453,9 +503,17 @@ class _SequenceNumericColumn(
[array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape],
axis=0)
dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)
- sequence_length = fc._sequence_length_from_sparse_tensor(
- sp_tensor, num_elements=self._variable_shape.num_elements())
+
+ # Get the number of timesteps per example
+ # For the 2D case, the raw values are grouped according to num_elements;
+ # for the 3D case, the grouping happens in the third dimension, and
+ # sequence length is not affected.
+ num_elements = (self._variable_shape.num_elements()
+ if sp_tensor.shape.ndims == 2 else 1)
+ seq_length = fc._sequence_length_from_sparse_tensor(
+ sp_tensor, num_elements=num_elements)
+
return fc._SequenceDenseColumn.TensorSequenceLengthPair(
- dense_tensor=dense_tensor, sequence_length=sequence_length)
+ dense_tensor=dense_tensor, sequence_length=seq_length)
# pylint: enable=protected-access
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py
new file mode 100644
index 0000000000..d8ca363627
--- /dev/null
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py
@@ -0,0 +1,280 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration test for sequence feature columns with SequenceExamples."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import string
+import tempfile
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.keras.layers import recurrent
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class SequenceFeatureColumnIntegrationTest(test.TestCase):
+
+ def _make_sequence_example(self):
+ example = example_pb2.SequenceExample()
+ example.context.feature['int_ctx'].int64_list.value.extend([5])
+ example.context.feature['float_ctx'].float_list.value.extend([123.6])
+ for val in range(0, 10, 2):
+ feat = feature_pb2.Feature()
+ feat.int64_list.value.extend([val] * val)
+ example.feature_lists.feature_list['int_list'].feature.extend([feat])
+ for val in range(1, 11, 2):
+ feat = feature_pb2.Feature()
+ feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val)
+ example.feature_lists.feature_list['str_list'].feature.extend([feat])
+
+ return example
+
+ def _build_feature_columns(self):
+ col = fc.categorical_column_with_identity(
+ 'int_ctx', num_buckets=100)
+ ctx_cols = [
+ fc.embedding_column(col, dimension=10),
+ fc.numeric_column('float_ctx')]
+
+ identity_col = sfc.sequence_categorical_column_with_identity(
+ 'int_list', num_buckets=10)
+ bucket_col = sfc.sequence_categorical_column_with_hash_bucket(
+ 'bytes_list', hash_bucket_size=100)
+ seq_cols = [
+ fc.embedding_column(identity_col, dimension=10),
+ fc.embedding_column(bucket_col, dimension=20)]
+
+ return ctx_cols, seq_cols
+
+ def test_sequence_example_into_input_layer(self):
+ examples = [_make_sequence_example().SerializeToString()] * 100
+ ctx_cols, seq_cols = self._build_feature_columns()
+
+ def _parse_example(example):
+ ctx, seq = parsing_ops.parse_single_sequence_example(
+ example,
+ context_features=fc.make_parse_example_spec(ctx_cols),
+ sequence_features=fc.make_parse_example_spec(seq_cols))
+ ctx.update(seq)
+ return ctx
+
+ ds = dataset_ops.Dataset.from_tensor_slices(examples)
+ ds = ds.map(_parse_example)
+ ds = ds.batch(20)
+
+ # Test on a single batch
+ features = ds.make_one_shot_iterator().get_next()
+
+ # Tile the context features across the sequence features
+ seq_layer, _ = sfc.sequence_input_layer(features, seq_cols)
+ ctx_layer = fc.input_layer(features, ctx_cols)
+ input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)
+
+ rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10))
+ output = rnn_layer(input_layer)
+
+ with self.cached_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ features_r = sess.run(features)
+ self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6])
+
+ output_r = sess.run(output)
+ self.assertAllEqual(output_r.shape, [20, 10])
+
+
+class SequenceExampleParsingTest(test.TestCase):
+
+ def test_seq_ex_in_sequence_categorical_column_with_identity(self):
+ self._test_parsed_sequence_example(
+ 'int_list', sfc.sequence_categorical_column_with_identity,
+ 10, [3, 6], [2, 4, 6])
+
+ def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self):
+ self._test_parsed_sequence_example(
+ 'bytes_list', sfc.sequence_categorical_column_with_hash_bucket,
+ 10, [3, 4], [compat.as_bytes(x) for x in 'acg'])
+
+ def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self):
+ self._test_parsed_sequence_example(
+ 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list,
+ list(string.ascii_lowercase), [3, 4],
+ [compat.as_bytes(x) for x in 'acg'])
+
+ def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self):
+ _, fname = tempfile.mkstemp()
+ with open(fname, 'w') as f:
+ f.write(string.ascii_lowercase)
+ self._test_parsed_sequence_example(
+ 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file,
+ fname, [3, 4], [compat.as_bytes(x) for x in 'acg'])
+
+ def _test_parsed_sequence_example(
+ self, col_name, col_fn, col_arg, shape, values):
+ """Helper function to check that each FeatureColumn parses correctly.
+
+ Args:
+ col_name: string, name to give to the feature column. Should match
+ the name that the column will parse out of the features dict.
+ col_fn: function used to create the feature column. For example,
+ sequence_numeric_column.
+ col_arg: second arg that the target feature column is expecting.
+ shape: the expected dense_shape of the feature after parsing into
+ a SparseTensor.
+ values: the expected values at index [0, 2, 6] of the feature
+ after parsing into a SparseTensor.
+ """
+ example = _make_sequence_example()
+ columns = [
+ fc.categorical_column_with_identity('int_ctx', num_buckets=100),
+ fc.numeric_column('float_ctx'),
+ col_fn(col_name, col_arg)
+ ]
+ context, seq_features = parsing_ops.parse_single_sequence_example(
+ example.SerializeToString(),
+ context_features=fc.make_parse_example_spec(columns[:2]),
+ sequence_features=fc.make_parse_example_spec(columns[2:]))
+
+ with self.cached_session() as sess:
+ ctx_result, seq_result = sess.run([context, seq_features])
+ self.assertEqual(list(seq_result[col_name].dense_shape), shape)
+ self.assertEqual(
+ list(seq_result[col_name].values[[0, 2, 6]]), values)
+ self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1])
+ self.assertEqual(ctx_result['int_ctx'].values[0], 5)
+ self.assertEqual(list(ctx_result['float_ctx'].shape), [1])
+ self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1)
+
+
+_SEQ_EX_PROTO = """
+context {
+ feature {
+ key: "float_ctx"
+ value {
+ float_list {
+ value: 123.6
+ }
+ }
+ }
+ feature {
+ key: "int_ctx"
+ value {
+ int64_list {
+ value: 5
+ }
+ }
+ }
+}
+feature_lists {
+ feature_list {
+ key: "bytes_list"
+ value {
+ feature {
+ bytes_list {
+ value: "a"
+ }
+ }
+ feature {
+ bytes_list {
+ value: "b"
+ value: "c"
+ }
+ }
+ feature {
+ bytes_list {
+ value: "d"
+ value: "e"
+ value: "f"
+ value: "g"
+ }
+ }
+ }
+ }
+ feature_list {
+ key: "float_list"
+ value {
+ feature {
+ float_list {
+ value: 1.0
+ }
+ }
+ feature {
+ float_list {
+ value: 3.0
+ value: 3.0
+ value: 3.0
+ }
+ }
+ feature {
+ float_list {
+ value: 5.0
+ value: 5.0
+ value: 5.0
+ value: 5.0
+ value: 5.0
+ }
+ }
+ }
+ }
+ feature_list {
+ key: "int_list"
+ value {
+ feature {
+ int64_list {
+ value: 2
+ value: 2
+ }
+ }
+ feature {
+ int64_list {
+ value: 4
+ value: 4
+ value: 4
+ value: 4
+ }
+ }
+ feature {
+ int64_list {
+ value: 6
+ value: 6
+ value: 6
+ value: 6
+ value: 6
+ value: 6
+ }
+ }
+ }
+ }
+}
+"""
+
+
+def _make_sequence_example():
+ example = example_pb2.SequenceExample()
+ return text_format.Parse(_SEQ_EX_PROTO, example)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 45d7b74046..929e83523a 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc
@@ -28,28 +29,61 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
-class SequenceInputLayerTest(test.TestCase):
+class SequenceInputLayerTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [2, 0]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2)),
+ 'expected_input_layer': [
+ # example 0, ids_a [2], ids_b [1]
+ [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
+ # example 1, ids_a [0, 1], ids_b [2, 0]
+ [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],],
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[2], [0, 1]]
+ # feature 1, ids [[0, 0], [1]]
+ indices=(
+ (0, 0, 0), (0, 1, 0), (0, 1, 1),
+ (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 0, 0, 1),
+ dense_shape=(2, 2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[1, 1], [1]]
+ # feature 1, ids [[2], [0]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(1, 1, 1, 2, 0),
+ dense_shape=(2, 2, 2)),
+ 'expected_input_layer': [
+ # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
+ [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]],
+ # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -]
+ [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]],
+ 'expected_sequence_length': [2, 2]},
+ )
+ def test_embedding_column(
+ self, sparse_input_a, sparse_input_b, expected_input_layer,
+ expected_sequence_length):
- def test_embedding_column(self):
vocabulary_size = 3
- sparse_input_a = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
- sparse_input_b = sparse_tensor.SparseTensorValue(
- # example 0, ids [1]
- # example 1, ids [2, 0]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(1, 2, 0),
- dense_shape=(2, 2))
-
embedding_dimension_a = 2
embedding_values_a = (
(1., 2.), # id 0
@@ -70,14 +104,6 @@ class SequenceInputLayerTest(test.TestCase):
return embedding_values
return _initializer
- expected_input_layer = [
- # example 0, ids_a [2], ids_b [1]
- [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
- # example 1, ids_a [0, 1], ids_b [2, 0]
- [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],
- ]
- expected_sequence_length = [1, 2]
-
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column_a = fc.embedding_column(
@@ -233,29 +259,53 @@ class SequenceInputLayerTest(test.TestCase):
},
feature_columns=shared_embedding_columns)
- def test_indicator_column(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [1, 0]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 1, 0),
+ dense_shape=(2, 2)),
+ 'expected_input_layer': [
+ # example 0, ids_a [2], ids_b [1]
+ [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]],
+ # example 1, ids_a [0, 1], ids_b [1, 0]
+ [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[2], [0, 1]]
+ # feature 1, ids [[0, 0], [1]]
+ indices=(
+ (0, 0, 0), (0, 1, 0), (0, 1, 1),
+ (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 0, 0, 1),
+ dense_shape=(2, 2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[1, 1], [1]]
+ # feature 1, ids [[1], [0]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(1, 1, 1, 1, 0),
+ dense_shape=(2, 2, 2)),
+ 'expected_input_layer': [
+ # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
+ [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]],
+ # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -]
+ [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
+ 'expected_sequence_length': [2, 2]},
+ )
+ def test_indicator_column(
+ self, sparse_input_a, sparse_input_b, expected_input_layer,
+ expected_sequence_length):
vocabulary_size_a = 3
- sparse_input_a = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
vocabulary_size_b = 2
- sparse_input_b = sparse_tensor.SparseTensorValue(
- # example 0, ids [1]
- # example 1, ids [1, 0]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(1, 1, 0),
- dense_shape=(2, 2))
-
- expected_input_layer = [
- # example 0, ids_a [2], ids_b [1]
- [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]],
- # example 1, ids_a [0, 1], ids_b [1, 0]
- [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]],
- ]
- expected_sequence_length = [1, 2]
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size_a)
@@ -298,18 +348,32 @@ class SequenceInputLayerTest(test.TestCase):
features={'aaa': sparse_input},
feature_columns=[indicator_column_a])
- def test_numeric_column(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0.], [1]]
- # example 1, [[10.]]
- indices=((0, 0), (0, 1), (1, 0)),
- values=(0., 1., 10.),
- dense_shape=(2, 2))
- expected_input_layer = [
- [[0.], [1.]],
- [[10.], [0.]],
- ]
- expected_sequence_length = [2, 1]
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [0., 1]
+ # example 1, [10.]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2)),
+ 'expected_input_layer': [
+ [[0.], [1.]],
+ [[10.], [0.]]],
+ 'expected_sequence_length': [2, 1]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[20, 3], [5]]
+ # feature 1, ids [[3], [8]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(20, 3, 5., 3., 8.),
+ dense_shape=(2, 2, 2)),
+ 'expected_input_layer': [
+ [[20.], [3.], [5.], [0.]],
+ [[3.], [0.], [8.], [0.]]],
+ 'expected_sequence_length': [2, 2]},
+ )
+ def test_numeric_column(
+ self, sparse_input, expected_input_layer, expected_sequence_length):
numeric_column = sfc.sequence_numeric_column('aaa')
input_layer, sequence_length = sfc.sequence_input_layer(
@@ -321,21 +385,38 @@ class SequenceInputLayerTest(test.TestCase):
self.assertAllEqual(
expected_sequence_length, sequence_length.eval(session=sess))
- def test_numeric_column_multi_dim(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [0., 1., 2., 3., 4., 5., 6., 7.]
+ # example 1, [10., 11., 12., 13.]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
+ (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8)),
+ 'expected_input_layer': [
+ # The output of numeric_column._get_dense_tensor should be flattened.
+ [[0., 1., 2., 3.], [4., 5., 6., 7.]],
+ [[10., 11., 12., 13.], [0., 0., 0., 0.]]],
+ 'expected_sequence_length': [2, 1]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
+ # example 1, [[10., 11., 12., 13.], []]
+ indices=((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
+ (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3),
+ (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 2, 4)),
+ 'expected_input_layer': [
+ # The output of numeric_column._get_dense_tensor should be flattened.
+ [[0., 1., 2., 3.], [4., 5., 6., 7.]],
+ [[10., 11., 12., 13.], [0., 0., 0., 0.]]],
+ 'expected_sequence_length': [2, 1]},
+ )
+ def test_numeric_column_multi_dim(
+ self, sparse_input, expected_input_layer, expected_sequence_length):
"""Tests sequence_input_layer for multi-dimensional numeric_column."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
- # example 1, [[[10., 11.], [12., 13.]]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
- (1, 0), (1, 1), (1, 2), (1, 3)),
- values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
- dense_shape=(2, 8))
- # The output of numeric_column._get_dense_tensor should be flattened.
- expected_input_layer = [
- [[0., 1., 2., 3.], [4., 5., 6., 7.]],
- [[10., 11., 12., 13.], [0., 0., 0., 0.]],
- ]
- expected_sequence_length = [2, 1]
numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
input_layer, sequence_length = sfc.sequence_input_layer(
@@ -377,6 +458,134 @@ class SequenceInputLayerTest(test.TestCase):
r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'):
sess.run(sequence_length)
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
+ # example 1, [[[10., 11.], [12., 13.]]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
+ (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8)),
+ 'expected_shape': [2, 2, 4]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
+ # example 1, [[10., 11., 12., 13.], []]
+ indices=((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
+ (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 2),
+ (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 2, 4)),
+ 'expected_shape': [2, 2, 4]},
+ )
+ def test_static_shape_from_tensors_numeric(
+ self, sparse_input, expected_shape):
+ """Tests that we return a known static shape when we have one."""
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
+
+ input_layer, _ = sfc.sequence_input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[numeric_column])
+ shape = input_layer.get_shape()
+ self.assertEqual(shape, expected_shape)
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2)),
+ 'expected_shape': [4, 2, 3]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ # example 2, ids []
+ # example 3, ids [[1], [0, 2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
+ (3, 0, 0), (3, 1, 0), (3, 1, 1)),
+ values=(2, 0, 1, 2, 1, 0, 2),
+ dense_shape=(4, 2, 2)),
+ 'expected_shape': [4, 2, 3]}
+ )
+ def test_static_shape_from_tensors_indicator(
+ self, sparse_input, expected_shape):
+ """Tests that we return a known static shape when we have one."""
+ categorical_column = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ indicator_column = fc.indicator_column(categorical_column)
+
+ input_layer, _ = sfc.sequence_input_layer(
+ features={'aaa': sparse_input}, feature_columns=[indicator_column])
+ shape = input_layer.get_shape()
+ self.assertEqual(shape, expected_shape)
+
+
+class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase):
+ """Tests the utility fn concatenate_context_input."""
+
+ def test_concatenate_context_input(self):
+ seq_input = ops.convert_to_tensor(np.arange(12).reshape(2, 3, 2))
+ context_input = ops.convert_to_tensor(np.arange(10).reshape(2, 5))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ input_layer = sfc.concatenate_context_input(context_input, seq_input)
+
+ expected = np.array([
+ [[0, 1, 0, 1, 2, 3, 4], [2, 3, 0, 1, 2, 3, 4], [4, 5, 0, 1, 2, 3, 4]],
+ [[6, 7, 5, 6, 7, 8, 9], [8, 9, 5, 6, 7, 8, 9], [10, 11, 5, 6, 7, 8, 9]]
+ ], dtype=np.float32)
+ with monitored_session.MonitoredSession() as sess:
+ output = sess.run(input_layer)
+ self.assertAllEqual(expected, output)
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'rank_lt_3',
+ 'seq_input': ops.convert_to_tensor(np.arange(100).reshape(10, 10))},
+ {'testcase_name': 'rank_gt_3',
+ 'seq_input': ops.convert_to_tensor(np.arange(100).reshape(5, 5, 2, 2))}
+ )
+ def test_sequence_input_throws_error(self, seq_input):
+ context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(ValueError, 'sequence_input must have rank 3'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'rank_lt_2',
+ 'context_input': ops.convert_to_tensor(np.arange(100))},
+ {'testcase_name': 'rank_gt_2',
+ 'context_input': ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))}
+ )
+ def test_context_input_throws_error(self, context_input):
+ seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(ValueError, 'context_input must have rank 2'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
+ def test_integer_seq_input_throws_error(self):
+ seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
+ context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10))
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(
+ TypeError, 'sequence_input must have dtype float32'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
+ def test_integer_context_input_throws_error(self):
+ seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
+ context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(
+ TypeError, 'context_input must have dtype float32'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
class InputLayerTest(test.TestCase):
"""Tests input_layer with sequence feature columns."""
@@ -443,75 +652,79 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual):
test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
-class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
-
- def test_get_sparse_tensors(self):
- column = sfc.sequence_categorical_column_with_identity(
- 'aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(1, 2, 0),
- dense_shape=(2, 2))
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=np.array((1, 2, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
+class SequenceCategoricalColumnWithIdentityTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((1, 2, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=(6, 7, 8),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=(6, 7, 8),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
+ column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9)
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_value(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
-
- def test_get_sparse_tensors_inputs3d(self):
- """Tests _get_sparse_tensors when the input is already 3D Tensor."""
- column = sfc.sequence_categorical_column_with_identity(
- 'aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=(1, 2, 0),
- dense_shape=(2, 2, 1))
-
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r'Column aaa expected ID tensor of rank 2\.\s*'
- r'id_tensor shape:\s*\[2 2 1\]'):
- id_weight_pair = column._get_sparse_tensors(
- _LazyBuilder({'aaa': inputs}))
- with monitored_session.MonitoredSession() as sess:
- id_weight_pair.id_tensor.eval(session=sess)
-
-
-class SequenceCategoricalColumnWithHashBucketTest(test.TestCase):
-
- def test_get_sparse_tensors(self):
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithHashBucketTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ # Ignored to avoid hash dependence in test.
+ values=np.array((0, 0, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ # Ignored to avoid hash dependence in test.
+ values=np.array((0, 0, 0), dtype=np.int64),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
column = sfc.sequence_categorical_column_with_hash_bucket(
'aaa', hash_bucket_size=10)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('omar', 'stringer', 'marlo'),
- dense_shape=(2, 2))
-
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- # Ignored to avoid hash dependence in test.
- values=np.array((0, 0, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_indices_shape(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
-class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
+class SequenceCategoricalColumnWithVocabularyFileTest(
+ test.TestCase, parameterized.TestCase):
def _write_vocab(self, vocab_strings, file_name):
vocab_file = os.path.join(self.get_temp_dir(), file_name)
@@ -527,68 +740,120 @@ class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
'wire_vocabulary.txt')
self._wire_vocabulary_size = 3
- def test_get_sparse_tensors(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=('omar', 'skywalker', 'marlo'),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=np.array((0, -1, 2), dtype=np.int64),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
column = sfc.sequence_categorical_column_with_vocabulary_file(
key='aaa',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=np.array((2, -1, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_value(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
-
-
-class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase):
-
- def test_get_sparse_tensors(self):
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithVocabularyListTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=('omar', 'skywalker', 'marlo'),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=np.array((0, -1, 2), dtype=np.int64),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
column = sfc.sequence_categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'))
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=np.array((2, -1, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_value(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
-
-
-class SequenceEmbeddingColumnTest(test.TestCase):
-
- def test_get_sequence_dense_tensor(self):
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
+
+
+class SequenceEmbeddingColumnTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2)),
+ 'expected': [
+ # example 0, ids [2]
+ [[7., 11.], [0., 0.]],
+ # example 1, ids [0, 1]
+ [[1., 2.], [3., 5.]],
+ # example 2, ids []
+ [[0., 0.], [0., 0.]],
+ # example 3, ids [1]
+ [[3., 5.], [0., 0.]]]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ # example 2, ids []
+ # example 3, ids [[1], [0, 2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
+ (3, 0, 0), (3, 1, 0), (3, 1, 1)),
+ values=(2, 0, 1, 2, 1, 0, 2),
+ dense_shape=(4, 2, 2)),
+ 'expected': [
+ # example 0, ids [[2]]
+ [[7., 11.], [0., 0.]],
+ # example 1, ids [[0, 1], [2]]
+ [[2, 3.5], [7., 11.]],
+ # example 2, ids []
+ [[0., 0.], [0., 0.]],
+ # example 3, ids [[1], [0, 2]]
+ [[3., 5.], [4., 6.5]]]}
+ )
+ def test_get_sequence_dense_tensor(self, inputs, expected):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 1), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 2))
-
embedding_dimension = 2
embedding_values = (
(1., 2.), # id 0
@@ -601,17 +866,6 @@ class SequenceEmbeddingColumnTest(test.TestCase):
self.assertIsNone(partition_info)
return embedding_values
- expected_lookups = [
- # example 0, ids [2]
- [[7., 11.], [0., 0.]],
- # example 1, ids [0, 1]
- [[1., 2.], [3., 5.]],
- # example 2, ids []
- [[0., 0.], [0., 0.]],
- # example 3, ids [1]
- [[3., 5.], [0., 0.]],
- ]
-
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column = fc.embedding_column(
@@ -619,24 +873,35 @@ class SequenceEmbeddingColumnTest(test.TestCase):
initializer=_initializer)
embedding_lookup, _ = embedding_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(
('embedding_weights:0',), tuple([v.name for v in global_vars]))
with monitored_session.MonitoredSession() as sess:
self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
- self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess))
-
- def test_sequence_length(self):
+ self.assertAllEqual(expected, embedding_lookup.eval(session=sess))
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 2),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2]}
+ )
+ def test_sequence_length(self, inputs, expected_sequence_length):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
@@ -644,7 +909,7 @@ class SequenceEmbeddingColumnTest(test.TestCase):
categorical_column, dimension=2)
_, sequence_length = embedding_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
@@ -855,56 +1120,87 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase):
expected_sequence_length_b, sequence_length_b.eval(session=sess))
-class SequenceIndicatorColumnTest(test.TestCase):
-
- def test_get_sequence_dense_tensor(self):
+class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2)),
+ 'expected': [
+ # example 0, ids [2]
+ [[0., 0., 1.], [0., 0., 0.]],
+ # example 1, ids [0, 1]
+ [[1., 0., 0.], [0., 1., 0.]],
+ # example 2, ids []
+ [[0., 0., 0.], [0., 0., 0.]],
+ # example 3, ids [1]
+ [[0., 1., 0.], [0., 0., 0.]]]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ # example 2, ids []
+ # example 3, ids [[1], [2, 2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
+ (3, 0, 0), (3, 1, 0), (3, 1, 1)),
+ values=(2, 0, 1, 2, 1, 2, 2),
+ dense_shape=(4, 2, 2)),
+ 'expected': [
+ # example 0, ids [[2]]
+ [[0., 0., 1.], [0., 0., 0.]],
+ # example 1, ids [[0, 1], [2]]
+ [[1., 1., 0.], [0., 0., 1.]],
+ # example 2, ids []
+ [[0., 0., 0.], [0., 0., 0.]],
+ # example 3, ids [[1], [2, 2]]
+ [[0., 1., 0.], [0., 0., 2.]]]}
+ )
+ def test_get_sequence_dense_tensor(self, inputs, expected):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 1), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 2))
-
- expected_lookups = [
- # example 0, ids [2]
- [[0., 0., 1.], [0., 0., 0.]],
- # example 1, ids [0, 1]
- [[1., 0., 0.], [0., 1., 0.]],
- # example 2, ids []
- [[0., 0., 0.], [0., 0., 0.]],
- # example 3, ids [1]
- [[0., 1., 0.], [0., 0., 0.]],
- ]
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column = fc.indicator_column(categorical_column)
indicator_tensor, _ = indicator_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess))
-
- def test_sequence_length(self):
+ self.assertAllEqual(expected, indicator_tensor.eval(session=sess))
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 2),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2]}
+ )
+ def test_sequence_length(self, inputs, expected_sequence_length):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column = fc.indicator_column(categorical_column)
_, sequence_length = indicator_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
@@ -938,7 +1234,7 @@ class SequenceIndicatorColumnTest(test.TestCase):
expected_sequence_length, sequence_length.eval(session=sess))
-class SequenceNumericColumnTest(test.TestCase):
+class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase):
def test_defaults(self):
a = sfc.sequence_numeric_column('aaa')
@@ -971,25 +1267,36 @@ class SequenceNumericColumnTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, 'must be a callable'):
sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable')
- def test_get_sequence_dense_tensor(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0.], [1]]
- # example 1, [[10.]]
- indices=((0, 0), (0, 1), (1, 0)),
- values=(0., 1., 10.),
- dense_shape=(2, 2))
- expected_dense_tensor = [
- [[0.], [1.]],
- [[10.], [0.]],
- ]
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, values [0., 1]
+ # example 1, [10.]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2)),
+ 'expected': [
+ [[0.], [1.]],
+ [[10.], [0.]]]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[20, 3], [5]]
+ # feature 1, ids [[3], [8]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(20, 3, 5., 3., 8.),
+ dense_shape=(2, 2, 2)),
+ 'expected': [
+ [[20.], [3.], [5.], [0.]],
+ [[3.], [0.], [8.], [0.]]]},
+ )
+ def test_get_sequence_dense_tensor(self, inputs, expected):
numeric_column = sfc.sequence_numeric_column('aaa')
dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_dense_tensor, dense_tensor.eval(session=sess))
+ self.assertAllEqual(expected, dense_tensor.eval(session=sess))
def test_get_sequence_dense_tensor_with_normalizer_fn(self):
@@ -1026,41 +1333,34 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))
- def test_get_sequence_dense_tensor_with_shape(self):
- """Tests get_sequence_dense_tensor with shape !=(1,)."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0., 1., 2.], [3., 4., 5.]]
- # example 1, [[10., 11., 12.]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
- (1, 0), (1, 1), (1, 2)),
- values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
- dense_shape=(2, 6))
- expected_dense_tensor = [
- [[0., 1., 2.], [3., 4., 5.]],
- [[10., 11., 12.], [0., 0., 0.]],
- ]
- numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
-
- dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_dense_tensor, dense_tensor.eval(session=sess))
-
- def test_get_dense_tensor_multi_dim(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
+ # example 1, [[[10., 11.], [12., 13.]]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
+ (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8)),
+ 'expected_dense_tensor': [
+ [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]],
+ [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]]]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (0, 0, 2), (0, 0, 4), (0, 0, 6),
+ (0, 1, 0), (0, 1, 2), (0, 1, 4), (0, 1, 6),
+ (1, 0, 0), (1, 0, 2), (1, 0, 4), (1, 0, 6)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 2, 8)),
+ 'expected_dense_tensor': [
+ [[[0., 0.], [1., 0.]], [[2., 0.], [3., 0.]],
+ [[4., 0.], [5., 0.]], [[6., 0.], [7., 0.]]],
+ [[[10., 0.], [11., 0.]], [[12., 0.], [13., 0.]],
+ [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]]]]},
+ )
+ def test_get_dense_tensor_multi_dim(
+ self, sparse_input, expected_dense_tensor):
"""Tests get_sequence_dense_tensor for multi-dim numeric_column."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
- # example 1, [[[10., 11.], [12., 13.]]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
- (1, 0), (1, 1), (1, 2), (1, 3)),
- values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
- dense_shape=(2, 8))
- expected_dense_tensor = [
- [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]],
- [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]],
- ]
numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
@@ -1070,43 +1370,55 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))
- def test_sequence_length(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0., 1., 2.], [3., 4., 5.]]
- # example 1, [[10., 11., 12.]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
- (1, 0), (1, 1), (1, 2)),
- values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
- dense_shape=(2, 6))
- expected_sequence_length = [2, 1]
- numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2., 0., 1.),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 2],
+ 'shape': (1,)},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2., 0., 1., 2.),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2],
+ 'shape': (1,)},
+ {'testcase_name': '2D_with_shape',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2., 0., 1.),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 1],
+ 'shape': (2,)},
+ {'testcase_name': '3D_with_shape',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2., 0., 1., 2.),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2],
+ 'shape': (2,)},
+ )
+ def test_sequence_length(self, inputs, expected_sequence_length, shape):
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=shape)
_, sequence_length = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
self.assertAllEqual(expected_sequence_length, sequence_length)
self.assertEqual(np.int64, sequence_length.dtype)
- def test_sequence_length_with_shape(self):
- """Tests _sequence_length with shape !=(1,)."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0.], [1]]
- # example 1, [[10.]]
- indices=((0, 0), (0, 1), (1, 0)),
- values=(0., 1., 10.),
- dense_shape=(2, 2))
- expected_sequence_length = [2, 1]
- numeric_column = sfc.sequence_numeric_column('aaa')
-
- _, sequence_length = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
-
def test_sequence_length_with_empty_rows(self):
"""Tests _sequence_length when some examples do not have ids."""
sparse_input = sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index bb06f1c41c..3549cedb70 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <fstream>
#include <list>
#include <map>
-#include <set>
#include <fcntl.h>
#include <rdma/rdma_cma.h>
@@ -30,19 +29,17 @@ limitations under the License.
#include <sys/epoll.h>
#include "tensorflow/contrib/gdr/gdr.pb.h"
-#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/common_runtime/process_state.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#endif // GOOGLE_CUDA
-#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
@@ -70,14 +67,11 @@ bool IsGDRAvailable() {
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
- return 0;
+ return port::kNUMANoAffinity;
#elif defined(PLATFORM_WINDOWS)
// Windows support for NUMA is not currently implemented. Return node 0.
- return 0;
+ return port::kNUMANoAffinity;
#else
- VLOG(2) << "Trying to read NUMA node for device: " << device->name;
- static const int kUnknownNumaNode = -1;
-
auto filename = string(device->ibdev_path) + "/device/numa_node";
std::ifstream ifs(filename.c_str());
@@ -91,12 +85,12 @@ int TryToReadNumaNode(ibv_device* device) {
<< value
<< "), but there must be at least one NUMA node"
", so returning NUMA node zero";
- return 0;
+ return port::kNUMANoAffinity;
}
LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
return value;
}
- return kUnknownNumaNode;
+ return port::kNUMANoAffinity;
#endif
}
@@ -138,8 +132,6 @@ class GdrMemoryManager : public RemoteMemoryManager {
Device* device, DeviceContext* device_context, bool on_host,
StatusCallback done) override;
- static void RegMemVisitors();
-
protected:
Status CreateEndpoint(const string& host, const string& port,
RdmaEndpointPtr& endpoint);
@@ -150,7 +142,8 @@ class GdrMemoryManager : public RemoteMemoryManager {
ibv_mr* FindMemoryRegion(void* addr, size_t length);
- void InsertMemoryRegion(void* addr, size_t length);
+ void InsertMemoryRegion(void* addr, size_t length,
+ const std::string& allocator_name);
void EvictMemoryRegion(void* addr, size_t length);
@@ -160,6 +153,7 @@ class GdrMemoryManager : public RemoteMemoryManager {
RdmaEndpointPtr listening_;
std::atomic<bool> stopped_;
int epfd_;
+ int numa_node_;
// Server side endpoints
// Accessed sequentially in Run() so not protected by lock
@@ -190,46 +184,10 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
port_(port),
listening_(nullptr, EndpointDeleter),
stopped_(true),
- next_key_(0) {
- static std::once_flag flag;
- std::call_once(flag, []() { RegMemVisitors(); });
-}
+ next_key_(0) {}
GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }
-/*static*/ void GdrMemoryManager::RegMemVisitors() {
- SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
- size_t num_bytes) {
- GdrMemoryManager::Singleton().InsertMemoryRegion(
- ptr, num_bytes, strings::StrCat("CPU:", numa_node));
- };
- SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
- size_t num_bytes) {
- GdrMemoryManager::Singleton().EvictMemoryRegion(ptr, num_bytes);
- };
- ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
- ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
-
-#if GOOGLE_CUDA
- if (IsGDRAvailable()) {
- int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
-
- // Note we don't free allocated GPU memory so there is no free visitor
- SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
- size_t num_bytes) {
- RdmaMemoryMgr::Singleton().InsertMemoryRegion(
- ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
- };
- GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
- cuda_alloc_visitor);
- GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
- alloc_visitor);
- GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
- LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
- }
-#endif // GOOGLE_CUDA
-}
-
Status GdrMemoryManager::Init() {
epfd_ = epoll_create1(0);
if (epfd_ == -1) {
@@ -289,6 +247,42 @@ Status GdrMemoryManager::Init() {
"cannot add server to epoll");
}
+ numa_node_ = TryToReadNumaNode(listening_->verbs->device);
+
+ SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node,
+ size_t num_bytes) {
+ VLOG(2) << "Registering RDMA capable memory region on numa_node "
+ << numa_node;
+ InsertMemoryRegion(ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+ };
+ SubAllocator::Visitor free_visitor = [this](void* ptr, int numa_node,
+ size_t num_bytes) {
+ VLOG(2) << "De-registering RDMA capable memory region on numa_node "
+ << numa_node;
+ EvictMemoryRegion(ptr, num_bytes);
+ };
+ ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
+ LOG(INFO) << "Instrumenting CPU allocator(s)";
+
+#if GOOGLE_CUDA
+ if (IsGDRAvailable()) {
+ int bus_id = numa_node_ + 1;
+
+ SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id,
+ size_t num_bytes) {
+ VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id;
+ InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+ };
+ GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
+ cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+ alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
+ LOG(INFO) << "Instrumenting GPU allocator(s) with bus_id " << bus_id;
+ }
+#endif // GOOGLE_CUDA
+
return Status::OK();
}
@@ -405,7 +399,7 @@ void GdrMemoryManager::TransportOptionsFromTensor(
ibv_mr* mr = FindMemoryRegion(addr, length);
#if GOOGLE_CUDA
- if (!on_host) {
+ if (device->tensorflow_gpu_device_info() && !on_host) {
Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
GPUUtil::CopyGPUTensorToCPU(
@@ -456,11 +450,27 @@ void GdrMemoryManager::TransportOptionsFromTensor(
#endif
if (mr == nullptr) {
- done(errors::Unavailable("Cannot find pinned memory region"));
- return;
+ Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_);
+ Tensor host_copy(alloc, tensor.dtype(), tensor.shape());
+
+ std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length);
+ VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer";
+
+ buffer = DMAHelper::buffer(&host_copy);
+ addr = buffer->data();
+ length = buffer->size();
+
+ mr = FindMemoryRegion(addr, length);
+ if (mr == nullptr) {
+ done(errors::Unavailable("Cannot find pinned memory region"));
+ return;
+ }
+
+ buffer->Ref();
+ } else {
+ buffer->Ref();
}
- buffer->Ref();
TensorKey tensor_key = next_key_++;
{
mutex_lock l(server_mu_);
@@ -470,7 +480,7 @@ void GdrMemoryManager::TransportOptionsFromTensor(
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
- if (!on_host) {
+ if (device->tensorflow_gpu_device_info() && !on_host) {
checksum = GPUUtil::Checksum(device, device_context, tensor);
} else {
checksum = GPUUtil::Checksum(tensor);
@@ -508,7 +518,8 @@ void GdrMemoryManager::TensorFromTransportOptions(
Tensor host_copy;
#if GOOGLE_CUDA
if (mr == nullptr && !on_host) {
- Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
+ Allocator* alloc =
+ GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_);
host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
buffer = DMAHelper::buffer(&host_copy);
addr = buffer->data();
@@ -518,8 +529,18 @@ void GdrMemoryManager::TensorFromTransportOptions(
#endif // GOOGLE_CUDA
if (mr == nullptr) {
- done(errors::Unavailable("Cannot find pinned memory region"));
- return;
+ Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_);
+ host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
+
+ buffer = DMAHelper::buffer(&host_copy);
+ addr = buffer->data();
+ length = buffer->size();
+
+ mr = FindMemoryRegion(addr, length);
+ if (mr == nullptr) {
+ done(errors::Unavailable("Cannot find pinned memory region"));
+ return;
+ }
}
decltype(clients_)::iterator iter;
@@ -568,7 +589,8 @@ void GdrMemoryManager::TensorFromTransportOptions(
}
#if GOOGLE_CUDA
- if (host_copy.NumElements() > 0) {
+ if (device->tensorflow_gpu_device_info() && !on_host &&
+ host_copy.NumElements() > 0) {
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
checksum = GPUUtil::Checksum(host_copy);
@@ -598,6 +620,12 @@ void GdrMemoryManager::TensorFromTransportOptions(
}
#endif // GOOGLE_CUDA
+ if ((on_host || !device->tensorflow_gpu_device_info()) &&
+ host_copy.NumElements() > 0) {
+ std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length);
+ VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer";
+ }
+
uint64_t end = Env::Default()->NowMicros();
VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
@@ -607,7 +635,7 @@ void GdrMemoryManager::TensorFromTransportOptions(
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
- if (device->tensorflow_gpu_device_info() && (!on_host)) {
+ if (device->tensorflow_gpu_device_info() && !on_host) {
checksum = GPUUtil::Checksum(device, device_context, *tensor);
} else {
checksum = GPUUtil::Checksum(*tensor);
@@ -668,7 +696,8 @@ ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
}
}
-void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
+void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length,
+ const std::string& allocator_name) {
if (length == 0) return;
ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
if (mr != nullptr) {
@@ -676,7 +705,8 @@ void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
mrs_.insert(iter, {mr, &MRDeleter});
} else {
- LOG(WARNING) << "Cannot register memory region";
+ LOG(WARNING) << "Cannot register memory region allocated by "
+ << allocator_name;
}
}
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index 44daf7adaa..1e65c3cee2 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -191,6 +191,13 @@ typedef struct {
TfLiteFusedActivation activation;
float cell_clip;
float proj_clip;
+} TfLiteUnidirectionalSequenceLSTMParams;
+
+typedef struct {
+ // Parameters for the LSTM kernel.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
// If true, store the outputs of both directions in the first output.
bool merge_outputs;
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index eac7db9a88..b092e5ee54 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
@@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
+ auto* params =
+ allocator->AllocatePOD<TfLiteUnidirectionalSequenceLSTMParams>();
+ if (auto* seq_lstm_params =
+ op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) {
+ params->activation =
+ parse_activation(seq_lstm_params->fused_activation_function());
+ params->cell_clip = seq_lstm_params->cell_clip();
+ params->proj_clip = seq_lstm_params->proj_clip();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
auto params =
allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png
new file mode 100644
index 0000000000..44d0ccd312
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_accuracy.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png
new file mode 100644
index 0000000000..94a6310612
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/performance/model_size_vs_latency.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 6b7943caf8..ed11452716 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -3,8 +3,15 @@
Mobile and embedded devices have limited computational resources and it is important to keep your application resource efficient. We have compiled a list of best practices and strategies you can use to optimize your model and application when using Tensorflow Lite.
-## Choose the most efficient model for the problem
-Some models may be too large to run on embedded devices. Instead of large models it is better to use a slightly less precise but smaller model for embedded devices. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices.
+## Choose the best model for the task
+Depending on the task you will need to make a tradeoff between model complexity and size. If your task requires high accuracy then you may need a large and complex model. Some tasks may work with a less precise model, for these tasks it is better to use a smaller but less precise model. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. For example, graphs below show accuracy and latency tradeoff for some common image classification models.
+
+![accuracy vs model size](images/performance/model_size_vs_accuracy.png "Accuracy vs Model size")
+
+
+![latency vs model size](images/performance/model_size_vs_latency.png "Latency vs Model size")
+
+One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices.
You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for
[image classification](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
@@ -12,25 +19,25 @@ You can retrain the listed models on your own dataset by using transfer learning
## Profile your model
-Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
+Once you have selected a candidate model that is right for your task, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
## Profile and optimize operators in the graph
If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator.
This scenario should be rare as Tensorflow Lite has optimized versions for most ops. However you may be able to write a faster version of a custom op, if you know the constraints in which the operator is executed. Check out our [custom operator documentation](custom_operators.md).
## Quantize your model
-If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. Fully quantized models can be remarkably power efficient as well.
+If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model.
## Tweak the number of threads
-Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/contrib/lite/interpreter.h#L337) threads.
+Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/contrib/lite/interpreter.h#L337) threads. Multi-threaded execution however comes at the cost of increased performance variability depending on what else is been executed concurrently. This is particularly the case for mobile apps. For example, isolated tests may show 2x speed up vs single-threaded but if another app is executing at the same time may result in worst performance than single-threaded.
## Eliminate redundant copies
-Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to [mmap a model file](https://github.com/tensorflow/tensorflow/blob/9982fd6c8831cbd2f58954f79ea71f26660393bc/tensorflow/contrib/lite/model.h#L152) and avoid copies. If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151).
+If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151).
## Profile your application with platform specific tools
Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform.
-## Use hardware accelerators available on the device
+## Evaluate whether your model benefits from using hardware accelerators available on the device
Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable Neural Networks API call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index b0f32a8d6c..2eb776d10c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -1,6 +1,22 @@
-
# Building TensorFlow on Android
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
To get you started working with TensorFlow on Android, we'll walk through two
ways to build our TensorFlow mobile demos and deploying them on an Android
device. The first is Android Studio, which lets you build and deploy in an
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index 49ad35d4e6..15f0fd3961 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -1,6 +1,22 @@
-
# Overview
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
index be8b4100c8..d922907cdc 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
@@ -1,6 +1,22 @@
-
# Building TensorFlow on iOS
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
## Using CocoaPods
The simplest way to get started with TensorFlow on iOS is using the CocoaPods
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
index 4d4bb3bc08..fd0e322c93 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
@@ -1,6 +1,22 @@
-
# Integrating TensorFlow libraries
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
Once you have made some progress on a model that addresses the problem you’re
trying to solve, it’s important to test it out inside your application
immediately. There are often unexpected differences between your training data
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
index 7436594fd8..59ff8e774c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
@@ -1,6 +1,22 @@
-
# Optimizing for mobile
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
There are some special issues that you have to deal with when you’re trying to
ship on mobile or embedded devices, and you’ll need to think about these as
you’re developing your model.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
index d1c67d4c61..1d373251dd 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
@@ -1,6 +1,22 @@
-
# Preparing models for mobile deployment
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
The requirements for storing model information during training are very
different from when you want to release it as part of a mobile app. This section
covers the tools involved in converting from a training model to something
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index 098ba7e773..e68cd26f81 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -11,6 +11,10 @@ load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary")
load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni")
+JAVA_SRCS = glob([
+ "src/main/java/org/tensorflow/lite/*.java",
+])
+
# Building tensorflow-lite.aar including 4 variants of .so
# To build an aar for release, run below command:
# bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
@@ -20,28 +24,38 @@ aar_with_jni(
android_library = ":tensorflowlite",
)
+# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite.
+aar_with_jni(
+ name = "tensorflow-lite-flex",
+ android_library = ":tensorflowlite_flex",
+)
+
android_library(
name = "tensorflowlite",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
+ manifest = "AndroidManifest.xml",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tensorflowlite_native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite.
+android_library(
+ name = "tensorflowlite_flex",
+ srcs = JAVA_SRCS,
manifest = "AndroidManifest.xml",
visibility = ["//visibility:public"],
deps = [
- ":tflite_runtime",
+ ":tensorflowlite_native_flex",
"@org_checkerframework_qual",
],
)
android_library(
name = "tensorflowlite_java",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
visibility = ["//visibility:public"],
deps = [
"@org_checkerframework_qual",
@@ -50,16 +64,23 @@ android_library(
java_library(
name = "tensorflowlitelib",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
javacopts = JAVACOPTS,
visibility = ["//visibility:public"],
deps = [
":libtensorflowlite_jni.so",
- "//tensorflow/contrib/lite/java/src/main/native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite.
+java_library(
+ name = "tensorflowlitelib_flex",
+ srcs = JAVA_SRCS,
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libtensorflowlite_flex_jni.so",
"@org_checkerframework_qual",
],
)
@@ -72,7 +93,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.TensorFlowLiteTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -87,7 +107,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.DataTypeTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -110,7 +129,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -125,13 +143,13 @@ java_test(
data = [
"src/testdata/add.bin",
"src/testdata/mobilenet.tflite.bin",
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
],
javacopts = JAVACOPTS,
tags = ["no_oss"],
test_class = "org.tensorflow.lite.InterpreterTest",
visibility = ["//visibility:private"],
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -139,6 +157,24 @@ java_test(
)
java_test(
+ name = "InterpreterFlexTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"],
+ data = [
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
+ ],
+ javacopts = JAVACOPTS,
+ tags = ["no_oss"],
+ test_class = "org.tensorflow.lite.InterpreterFlexTest",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":tensorflowlitelib_flex",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
name = "TensorTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"],
@@ -164,14 +200,29 @@ filegroup(
)
cc_library(
- name = "tflite_runtime",
+ name = "tensorflowlite_native",
srcs = ["libtensorflowlite_jni.so"],
visibility = ["//visibility:public"],
)
+cc_library(
+ name = "tensorflowlite_native_flex",
+ srcs = ["libtensorflowlite_flex_jni.so"],
+ visibility = ["//visibility:public"],
+)
+
tflite_jni_binary(
name = "libtensorflowlite_jni.so",
deps = [
"//tensorflow/contrib/lite/java/src/main/native",
],
)
+
+# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite.
+tflite_jni_binary(
+ name = "libtensorflowlite_flex_jni.so",
+ deps = [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ ],
+)
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index 9d2aead266..360d622b1b 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -30,7 +30,10 @@ EOF
# In some platforms we don't have an Android SDK/NDK and this target
# can't be built. We need to prevent the build system from trying to
# use the target in that case.
- tags = ["manual"],
+ tags = [
+ "manual",
+ "no_cuda_on_cpu_tap",
+ ],
)
native.genrule(
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
index 711638a9f9..d5447b3bf8 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -18,7 +18,8 @@ package org.tensorflow.lite;
/** Static utility methods loading the TensorFlowLite runtime. */
public final class TensorFlowLite {
- private static final String LIBNAME = "tensorflowlite_jni";
+ private static final String PRIMARY_LIBNAME = "tensorflowlite_jni";
+ private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni";
private TensorFlowLite() {}
@@ -29,13 +30,24 @@ public final class TensorFlowLite {
* Load the TensorFlowLite runtime C library.
*/
static boolean init() {
+ Throwable primaryLibException;
try {
- System.loadLibrary(LIBNAME);
+ System.loadLibrary(PRIMARY_LIBNAME);
return true;
} catch (UnsatisfiedLinkError e) {
- System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage());
- return false;
+ primaryLibException = e;
}
+
+ try {
+ System.loadLibrary(FALLBACK_LIBNAME);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ // If the fallback fails, log the error for the primary load instead.
+ System.err.println(
+ "TensorFlowLite: failed to load native library: " + primaryLibException.getMessage());
+ }
+
+ return false;
}
static {
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
new file mode 100644
index 0000000000..2791c3864b
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import java.io.File;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
+ * have TensorFlow ops.
+ */
+@RunWith(JUnit4.class)
+public final class InterpreterFlexTest {
+
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
+ /** Smoke test validating that flex model loading works when the flex delegate is linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) {
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.run(new float[1], new float[1]);
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index a98fca0132..f8b73c7cf3 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -43,6 +43,9 @@ public final class InterpreterTest {
private static final File MOBILENET_MODEL_FILE =
new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
@Test
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
@@ -345,4 +348,15 @@ public final class InterpreterTest {
interpreter.close();
interpreter.close();
}
+
+ /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try {
+ new Interpreter(FLEX_MODEL_FILE);
+ fail();
+ } catch (IllegalStateException e) {
+ // Expected failure.
+ }
+ }
}
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 68636fb070..d2d8073abd 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -259,6 +259,7 @@ cc_library(
srcs = ["lstm_eval.cc"],
hdrs = ["lstm_eval.h"],
deps = [
+ ":op_macros",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index afb5ec05df..5c9ca6e910 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -49,6 +49,20 @@ cc_library(
],
)
+cc_library(
+ name = "legacy_types",
+ srcs = [],
+ hdrs = [
+ "compatibility.h",
+ "legacy_types.h",
+ "types.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "@com_google_absl//absl/base:core_headers",
+ ],
+)
+
config_setting(
name = "arm",
values = {
@@ -198,6 +212,7 @@ cc_library(
":strided_slice_logic",
":tensor_utils",
":types",
+ ":legacy_types",
":legacy_reference_base",
":round",
"//third_party/eigen3",
@@ -336,6 +351,7 @@ cc_library(
":quantization_util",
":round",
":strided_slice_logic",
+ ":legacy_types",
":types",
"@gemmlowp",
"//tensorflow/contrib/lite/c:c_api_internal",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/contrib/lite/kernels/internal/legacy_types.h
index 35b4b4e20b..2e4d3137f5 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_options.cc
+++ b/tensorflow/contrib/lite/kernels/internal/legacy_types.h
@@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_
-#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
-namespace xla {
-namespace gpu {
+namespace tflite {
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config) {
- return !config.debug_options().xla_backend_extra_options().count(
- "xla_gpu_experimental_conv_disable_layout_heuristic");
-}
+// TODO(b/116772710): Insert legacy Dims<> code in here.
-} // namespace gpu
-} // namespace xla
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index be99240b1f..c8b64cfd96 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/legacy_types.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
-#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
@@ -30,6 +30,11 @@ namespace reference_ops {
static constexpr int kDepthwiseReverseShift = -1;
+inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
+ shape->BuildFrom(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 59f17ae854..19d23fa80b 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -100,11 +100,6 @@ gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
namespace reference_ops {
-inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
- shape->BuildFrom(
- {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc
index c6c21eb085..20a4e30009 100644
--- a/tensorflow/contrib/lite/kernels/lstm_eval.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
@@ -599,6 +600,7 @@ TfLiteStatus EvalFloat(
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
const int n_batch = input->dims->data[input->dims->size - 2];
const int n_input = input->dims->data[input->dims->size - 1];
@@ -716,6 +718,7 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
TfLiteTensor* output_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
const int n_batch = input->dims->data[input->dims->size - 2];
const int n_input = input->dims->data[input->dims->size - 1];
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index ec9cf38b83..89d57e4599 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -431,7 +431,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params =
+ reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
@@ -482,6 +484,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ // Copy out the LSTM specific params so they can be passed in the function.
+ TfLiteLSTMParams lstm_params;
+ lstm_params.activation = params->activation;
+ lstm_params.cell_clip = params->cell_clip;
+ lstm_params.proj_clip = params->proj_clip;
+
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
return lstm_eval::EvalFloat(
@@ -496,7 +504,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
- projection_bias, params, /*forward_sequence=*/true,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
output);
}
@@ -523,7 +531,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
/*aux_input_to_cell_weights=*/nullptr,
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
- projection_bias, params, /*forward_sequence=*/true,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
/*output_offset=*/0, scratch_buffer, scaling_factors,
prod_scaling_factors, recovered_cell_weights, input_quantized,
/*aux_input_quantized=*/nullptr, activation_state_quantized,
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index cd3aac0532..c97b0fdd61 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -110,11 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
- .Union());
+ SetBuiltinOp(
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+ CreateUnidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
+ .Union());
BuildInterpreter(input_shapes);
}
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index ff8430827c..cb7a282743 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -250,6 +250,7 @@ union BuiltinOptions {
FillOptions,
BidirectionalSequenceLSTMOptions,
BidirectionalSequenceRNNOptions,
+ UnidirectionalSequenceLSTMOptions,
}
enum Padding : byte { SAME, VALID }
@@ -394,6 +395,13 @@ table LSTMOptions {
kernel_type: LSTMKernelType = FULL;
}
+// An implementation of TensorFlow dynamic_rnn with LSTMCell.
+table UnidirectionalSequenceLSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
table BidirectionalSequenceLSTMOptions {
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index f3cb113c9c..e7b7a59def 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -79,6 +79,9 @@ struct LocalResponseNormalizationOptionsT;
struct LSTMOptions;
struct LSTMOptionsT;
+struct UnidirectionalSequenceLSTMOptions;
+struct UnidirectionalSequenceLSTMOptionsT;
+
struct BidirectionalSequenceLSTMOptions;
struct BidirectionalSequenceLSTMOptionsT;
@@ -681,11 +684,12 @@ enum BuiltinOptions {
BuiltinOptions_FillOptions = 68,
BuiltinOptions_BidirectionalSequenceLSTMOptions = 69,
BuiltinOptions_BidirectionalSequenceRNNOptions = 70,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions = 71,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_BidirectionalSequenceRNNOptions
+ BuiltinOptions_MAX = BuiltinOptions_UnidirectionalSequenceLSTMOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[72] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -757,7 +761,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
BuiltinOptions_ZerosLikeOptions,
BuiltinOptions_FillOptions,
BuiltinOptions_BidirectionalSequenceLSTMOptions,
- BuiltinOptions_BidirectionalSequenceRNNOptions
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions
};
return values;
}
@@ -835,6 +840,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"FillOptions",
"BidirectionalSequenceLSTMOptions",
"BidirectionalSequenceRNNOptions",
+ "UnidirectionalSequenceLSTMOptions",
nullptr
};
return names;
@@ -1129,6 +1135,10 @@ template<> struct BuiltinOptionsTraits<BidirectionalSequenceRNNOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions;
};
+template<> struct BuiltinOptionsTraits<UnidirectionalSequenceLSTMOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1720,6 +1730,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_BidirectionalSequenceRNNOptions ?
reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value) : nullptr;
}
+ UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() {
+ return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
+ const UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() const {
+ return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<const UnidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -3469,6 +3487,84 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
+ typedef UnidirectionalSequenceLSTMOptions TableType;
+ ActivationFunctionType fused_activation_function;
+ float cell_clip;
+ float proj_clip;
+ UnidirectionalSequenceLSTMOptionsT()
+ : fused_activation_function(ActivationFunctionType_NONE),
+ cell_clip(0.0f),
+ proj_clip(0.0f) {
+ }
+};
+
+struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnidirectionalSequenceLSTMOptionsT NativeTableType;
+ enum {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_CELL_CLIP = 6,
+ VT_PROJ_CLIP = 8
+ };
+ ActivationFunctionType fused_activation_function() const {
+ return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ float cell_clip() const {
+ return GetField<float>(VT_CELL_CLIP, 0.0f);
+ }
+ float proj_clip() const {
+ return GetField<float>(VT_PROJ_CLIP, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<float>(verifier, VT_CELL_CLIP) &&
+ VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+ verifier.EndTable();
+ }
+ UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct UnidirectionalSequenceLSTMOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+ fbb_.AddElement<int8_t>(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_cell_clip(float cell_clip) {
+ fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
+ }
+ void add_proj_clip(float proj_clip) {
+ fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
+ }
+ explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnidirectionalSequenceLSTMOptionsBuilder &operator=(const UnidirectionalSequenceLSTMOptionsBuilder &);
+ flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnidirectionalSequenceLSTMOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ float cell_clip = 0.0f,
+ float proj_clip = 0.0f) {
+ UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
+ builder_.add_proj_clip(proj_clip);
+ builder_.add_cell_clip(cell_clip);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
typedef BidirectionalSequenceLSTMOptions TableType;
ActivationFunctionType fused_activation_function;
@@ -6488,6 +6584,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const {
return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast<const BidirectionalSequenceRNNOptions *>(builtin_options()) : nullptr;
}
+ const UnidirectionalSequenceLSTMOptions *builtin_options_as_UnidirectionalSequenceLSTMOptions() const {
+ return builtin_options_type() == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? static_cast<const UnidirectionalSequenceLSTMOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6799,6 +6898,10 @@ template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_optio
return builtin_options_as_BidirectionalSequenceRNNOptions();
}
+template<> inline const UnidirectionalSequenceLSTMOptions *Operator::builtin_options_as<UnidirectionalSequenceLSTMOptions>() const {
+ return builtin_options_as_UnidirectionalSequenceLSTMOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7809,6 +7912,38 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
_kernel_type);
}
+inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new UnidirectionalSequenceLSTMOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = cell_clip(); _o->cell_clip = _e; };
+ { auto _e = proj_clip(); _o->proj_clip = _e; };
+}
+
+inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateUnidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnidirectionalSequenceLSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _fused_activation_function = _o->fused_activation_function;
+ auto _cell_clip = _o->cell_clip;
+ auto _proj_clip = _o->proj_clip;
+ return tflite::CreateUnidirectionalSequenceLSTMOptions(
+ _fbb,
+ _fused_activation_function,
+ _cell_clip,
+ _proj_clip);
+}
+
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BidirectionalSequenceLSTMOptionsT();
UnPackTo(_o, _resolver);
@@ -9620,6 +9755,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9918,6 +10057,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -10204,6 +10347,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value);
return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptionsT *>(value);
+ return CreateUnidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -10490,6 +10637,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ value = new UnidirectionalSequenceLSTMOptionsT(*reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10847,6 +10998,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index c698a9567a..5364eebbc9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -27,6 +27,73 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace toco {
+namespace {
+
+// Using the function reducer, reduce input along all axes in axes.
+// Put the reduced data in output, which should aleady be appropriately sized.
+// check_output_shape is set to what this code computes the final shape
+// to be, so it can be cross checked with the shape computation logic.
+void ReduceGeneric(bool keep_dims, const std::vector<int>& axes,
+ const Shape& input_shape, const std::vector<float>& input,
+ Shape* check_output_shape, std::vector<float>* output,
+ const std::function<float(float, float)>& reducer) {
+ if (!IsNonEmpty(input_shape)) {
+ // Zero-dimensions will break the NextIndices() logic, so just early out if
+ // we have an empty shape.
+ return;
+ }
+
+ // Set up output_shape to be the same length as input_shape, with
+ // appropriate dimensions squashed to 1. If keep_dims is false, we'll strip
+ // out the one dimensions at the end, but it's convenient to leave them for
+ // now. We recompute the shape because we need the output shape to have
+ // 1-dims in all the squashed dimensions; the shape from shape computation may
+ // remove those squashed dimensions, depending on the options used.
+ Shape output_shape = input_shape;
+
+ // Reduction mask will be elementwise multiplied against the input
+ // indices to figure out the output index for the element.
+ std::vector<int> reduction_mask(input_shape.dimensions_count(), 1);
+ for (int axis : axes) {
+ CHECK_GE(axis, 0);
+ CHECK_LT(axis, input_shape.dimensions_count());
+ reduction_mask[axis] = 0;
+ output_shape.mutable_dims()->at(axis) = 1;
+ }
+
+ std::vector<int> output_indices(input_shape.dimensions_count());
+ for (int input_offset = 0; input_offset < input.size(); ++input_offset) {
+ std::vector<int> input_indices = ReverseOffset(input_shape, input_offset);
+ // Calculate the output location by squashing input indices to 0
+ // in reduced axes.
+ for (int i = 0; i < input_shape.dimensions_count(); ++i) {
+ output_indices[i] = input_indices[i] * reduction_mask[i];
+ }
+ int output_offset = Offset(output_shape, output_indices);
+ if (input_indices == output_indices) {
+ // Base element for the reduced axes
+ output->at(output_offset) = input.at(input_offset);
+ } else {
+ // Reduce with existing element.
+ output->at(output_offset) =
+ reducer(output->at(output_offset), input.at(input_offset));
+ }
+ }
+
+ if (!keep_dims) {
+ // Strip out the dims from output_shape.
+ std::vector<int> new_dims;
+ for (int i = 0; i < output_shape.dimensions_count(); ++i) {
+ if (reduction_mask[i]) {
+ new_dims.push_back(output_shape.dims(i));
+ }
+ }
+ output_shape.mutable_dims()->swap(new_dims);
+ }
+ *check_output_shape = output_shape;
+}
+
+} // namespace
bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
auto& output_array = model->GetArray(op.outputs[0]);
@@ -176,27 +243,19 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
- int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
- CHECK_LT(axis, input_shape.dimensions_count()) << "Axis out of bounds";
- // We currently only handle reduction on axis 0.
- CHECK_EQ(axis, 0) << "Only reduction along axis 0 is supported";
- // We currently only handle 1-D and 2-D input tensors.
- CHECK_LE(input_shape.dimensions_count(), 2) << "Rank >2 not yet supported";
// We only support keep_dims=true; shape prop will need to change otherwise.
auto sum_op = static_cast<const TensorFlowSumOperator*>(unary_op);
- CHECK(sum_op->keep_dims) << "Only keep_dims=true is supported";
+ Shape check_output_shape;
- std::vector<int> indices(input_shape.dimensions_count());
- for (int i = 0; i < input_shape.dims(1); ++i) {
- indices[1] = i;
- float sum = 0.f;
- for (int j = 0; j < input_shape.dims(0); ++j) {
- indices[0] = j;
- sum += (*input_float_data)[Offset(input_shape, indices)];
- }
- output_float_data[i] = sum;
- }
+ ReduceGeneric(
+ sum_op->keep_dims, axis_array.GetBuffer<ArrayDataType::kInt32>().data,
+ input_shape, *input_float_data, &check_output_shape, &output_float_data,
+ [](float existing, float current) -> float {
+ return existing + current;
+ });
+ CHECK(check_output_shape == output_shape)
+ << "Shape propagation output shape doesn't match output shape from op";
} else if (unary_op->type == OperatorType::kReduceMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index acf1e3ede5..6f1be298ca 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -30,3 +30,16 @@ tf_cc_test(
"@com_google_googletest//:gtest_main",
],
)
+
+tf_cc_test(
+ name = "resolve_constant_unary_test",
+ srcs = ["resolve_constant_unary_test.cc"],
+ tags = ["no_oss"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
new file mode 100644
index 0000000000..a53abc9941
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+void RunResolveSum(const std::vector<float>& input,
+ const std::vector<int>& input_shape,
+ const std::vector<int>& axis,
+ const std::vector<int>& output_shape,
+ const std::vector<float>& expected_output) {
+ Model model;
+ Array& input0 = model.GetOrCreateArray("input0");
+ Array& input1 = model.GetOrCreateArray("input1");
+ Array& output = model.GetOrCreateArray("output");
+
+ *input0.mutable_shape()->mutable_dims() = input_shape;
+ input0.data_type = ArrayDataType::kFloat;
+ input0.GetMutableBuffer<ArrayDataType::kFloat>().data = input;
+
+ *input1.mutable_shape()->mutable_dims() = {static_cast<int>(axis.size())};
+ input1.GetMutableBuffer<ArrayDataType::kInt32>().data = axis;
+ input1.data_type = ArrayDataType::kInt32;
+
+ *output.mutable_shape()->mutable_dims() = output_shape;
+
+ auto sum_op = absl::make_unique<TensorFlowSumOperator>();
+ sum_op->keep_dims = true;
+ sum_op->inputs = {"input0", "input1"};
+ sum_op->outputs = {"output"};
+ model.operators.push_back(std::move(sum_op));
+ ResolveConstantUnaryOperator().Run(&model, 0);
+ EXPECT_EQ(model.GetArray("output").GetBuffer<ArrayDataType::kFloat>().data,
+ expected_output);
+ EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape);
+}
+
+// Reduce a 2d array across axis 0
+TEST(ResolveConstantUnary, ResolveSumAxis0_2D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ {3, 1, 4, 1,
+ 5, 9, 2, 6,
+ 5, 3, 5, 8},
+
+ // Input shape
+ {3, 4},
+
+ // Axes
+ {0},
+
+ // Expected output shape,
+ {1, 4},
+
+ // Expected output
+ {13, 13, 11, 15});
+ // clang-format on
+}
+
+// Reduce a 2d array across axis 1
+TEST(ResolveConstantUnary, ResolveSumAxis1_2D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ {3, 1, 4, 1,
+ 5, 9, 2, 6,
+ 5, 3, 5, 8},
+
+ // Input shape
+ {3, 4},
+
+ // Axes
+ {1},
+
+ // Expected output shape,
+ {3, 1},
+
+ // Expected output
+ {9, 22, 21});
+ // clang-format on
+}
+
+// Reduce a 3d tensor across axes 0 and 2.
+TEST(ResolveConstantUnary, ResolveSumAxis0_2_3D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ { 0, 1, 2,
+ 3, 10, 11,
+ 12, 13, 20,
+ 21, 22, 23,
+
+ 100, 101, 102,
+ 103, 110, 111,
+ 112, 113, 120,
+ 121, 122, 123,
+
+ 200, 201, 202,
+ 203, 210, 211,
+ 212, 213, 220,
+ 221, 222, 223 },
+
+ // Input shape
+ {3, 4, 3},
+
+ // Axes
+ {0, 2},
+
+ // Expected output shape,
+ {1, 4, 1},
+
+ // Expected output, generated using octave.
+ { 909, 972, 1035, 1098});
+ // clang-format on
+}
+
+} // namespace
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 5eaf6e27fc..133ef79a34 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -477,6 +477,30 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
+// Retain TensorFlow NodeDef in Toco Operator.
+//
+// If an op is supported by Toco but not supported by TFLite, TFLite exporter
+// will use the retained NodeDef to populate a Flex op when Flex mode is
+// enabled.
+//
+// This can't be easily applied to all operations, because a TensorFlow node
+// may become multiple Toco operators. Thus we need to call this function in
+// operator conversion functions one by one whenever feasible.
+//
+// This may cause problems if a graph transformation rule changes parameters
+// of the node. When calling this function, please check if any existing
+// graph transformation rule will change an existing operator with the same
+// type.
+//
+// This provides a route to handle Toco-supported & TFLite-unsupported ops
+// in Flex mode. However it's not a solid solution. Eventually we should
+// get rid of this.
+// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
+// this function.
+void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
+ node.SerializeToString(&op->tensorflow_node_def);
+}
+
tensorflow::Status ConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -990,6 +1014,10 @@ tensorflow::Status ConvertBatchMatMulOperator(
auto* batch_matmul = new BatchMatMulOperator;
batch_matmul->inputs = {node.input(0), node.input(1)};
batch_matmul->outputs = {node.name()};
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, batch_matmul);
+
model->operators.emplace_back(batch_matmul);
return tensorflow::Status::OK();
}
@@ -1081,7 +1109,10 @@ tensorflow::Status ConvertUnsupportedOperator(
auto* op = new TensorFlowUnsupportedOperator;
op->tensorflow_op = node.op();
- node.SerializeToString(&op->tensorflow_node_def);
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
// Parse inputs.
@@ -1605,6 +1636,10 @@ tensorflow::Status ConvertRangeOperator(
op->inputs.push_back(node.input(1));
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 6e207fdf54..61f1f095e9 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -376,6 +376,13 @@ struct Operator {
// looks unused.
bool unresolved_outputs = false;
+ // A serialized tensorflow::NodeDef string.
+ // The field is filled only when importing from TensorFlow.
+ // It's guaranteed to be filled for `TensorFlowUnsupportedOperator`.
+ // It's not guaranteed to be filled for other ops. Ops created by graph
+ // transformations won't have TensorFlow NodeDef.
+ string tensorflow_node_def;
+
protected:
// Constructor used by subclasses for specific OperatorType's.
explicit Operator(OperatorType t)
@@ -1535,8 +1542,6 @@ struct TensorFlowUnsupportedOperator : Operator {
// The original TF operation type. Used for diagnostic purposes.
string tensorflow_op;
- // A serialized tensorflow::NodeDef string.
- string tensorflow_node_def;
// A boolean indicating if the unsupported op should be treated as quantized.
bool quantized = false;
// A boolean indicating if the unsupported op output should allow float values
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index f6f76e48a4..3b34cd6285 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -95,11 +95,13 @@ OperatorKey GetOperatorKey(
const ::toco::Operator& op,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
bool allow_flex_ops) {
+ // Get the op name (by Toco definition).
string name = HelpfulOperatorTypeName(op);
- const auto& builtin_ops = GetBuiltinOpsMap();
bool is_builtin = false;
OperatorKey key;
+
+ const auto& builtin_ops = GetBuiltinOpsMap();
if (ops_by_type.count(op.type) != 0) {
key.version = ops_by_type.at(op.type)->GetVersion(op);
name = ops_by_type.at(op.type)->name();
@@ -110,37 +112,46 @@ OperatorKey GetOperatorKey(
// For TFLite supported builtin ops, find out its BuiltinOperator enum used
// in FlatBuffer.
key.type = builtin_ops.at(name);
- } else {
- key.type = BuiltinOperator_CUSTOM;
-
- key.is_custom_op = true;
- if (op.type == OperatorType::kUnsupported) {
- const TensorFlowUnsupportedOperator& unsupported_op =
- static_cast<const TensorFlowUnsupportedOperator&>(op);
- const auto tensorflow_op = unsupported_op.tensorflow_op;
-
- // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
- // to populate a regular custom op. We need to find a way to fix this.
- if (allow_flex_ops) {
- // Memorize the original TensorFlow op name.
- key.flex_tensorflow_op = tensorflow_op;
- // Prefix the custom code of the flex op.
- key.custom_code =
- string(::tflite::kFlexCustomCodePrefix) + tensorflow_op;
- key.is_flex_op = true;
-
- if (IsControlFlowOp(tensorflow_op)) {
- key.is_unsupported_flex_op = true;
- }
- } else {
- key.custom_code = tensorflow_op;
- }
+ return key;
+ }
+
+ // The logic below is all for custom ops.
+ key.is_custom_op = true;
+ key.type = BuiltinOperator_CUSTOM;
+
+ if (op.type == OperatorType::kUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ const auto tensorflow_op = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
+ // to populate a regular custom op. We need to find a way to fix this.
+ if (allow_flex_ops) {
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = tensorflow_op;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
} else {
- // For Toco-supported/TFLite-unsupported ops, currently we produce a
- // custom op. This gives developers a chance to implement custom ops.
- // TODO(b/116800229): Also produce Toco-supported/TFLite-unsupported ops
- // as Flex ops when Flex mode is enabled.
- key.custom_code = name;
+ key.custom_code = tensorflow_op;
+ }
+ } else if (allow_flex_ops && !op.tensorflow_node_def.empty()) {
+ // For Toco-supported/TFLite-unsupported ops, if the TensorFlow NodeDef
+ // is retained in the Toco Operator, we produce a Flex op if Flex mode
+ // is enabled.
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = name;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
+ } else {
+ // If Flex is disabled or the original TensorFlow NodeDef isn't available,
+ // we produce a custom op. This gives developers a chance to implemenr
+ // custom ops.
+ key.custom_code = name;
+ }
+
+ if (key.is_flex_op) {
+ if (IsControlFlowOp(key.flex_tensorflow_op)) {
+ key.is_unsupported_flex_op = true;
}
}
return key;
@@ -323,8 +334,9 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(
- details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
+ const auto key =
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ int op_index = operators_map.at(key);
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -349,6 +361,11 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
variable_tensor_indices->insert(variable_tensor_index);
}
}
+ } else if (key.is_flex_op && !op->tensorflow_node_def.empty()) {
+ auto fbb = WriteFlexOpOptions(op->tensorflow_node_def);
+ if (fbb) {
+ options = Options::Custom(builder->CreateVector(fbb->GetBuffer()));
+ }
}
// The only supported CustomOptionFormat is FLEXBUFFERS now.
op_vector.push_back(CreateOperator(
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index d48ab78285..eda1aa78a3 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include "tensorflow/core/framework/node_def.pb.h"
namespace toco {
namespace tflite {
@@ -382,6 +383,39 @@ TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
EXPECT_TRUE(key.is_unsupported_flex_op);
}
+TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
+ // Test Toco-supported/TFLite-unsupported operators.
+ // TODO(ycling): The test will be broken if Range is implemented in TFLite.
+ // Find a more robust way to test the fallback logic.
+ auto op = absl::make_unique<RangeOperator>();
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+
+ {
+ // If NodeDef isn't retained in the Toco op, a regular custom op
+ // will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "Range");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_FALSE(key.is_flex_op);
+ }
+
+ ::tensorflow::NodeDef node_def;
+ node_def.set_name("Range");
+ node_def.set_op("Range");
+ node_def.SerializeToString(&op->tensorflow_node_def);
+
+ {
+ // If NodeDef is retained in the Toco op, a Flex op will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexRange");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ }
+}
+
// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
} // namespace
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 9addbb81e7..ed37535fe0 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1157,6 +1157,25 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def) {
+ auto fbb = absl::make_unique<flexbuffers::Builder>();
+
+ ::tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(tensorflow_node_def)) {
+ LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
+ return {};
+ }
+
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing flex op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+}
+
class TensorFlowUnsupported : public BaseOperator {
public:
TensorFlowUnsupported(const string& name, OperatorType type,
@@ -1192,6 +1211,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<flexbuffers::Builder> WriteOptions(
const TensorFlowUnsupportedOperator& op) const {
+ if (allow_flex_ops_) {
+ return WriteFlexOpOptions(op.tensorflow_node_def);
+ }
auto fbb = absl::make_unique<flexbuffers::Builder>();
::tensorflow::NodeDef node_def;
@@ -1200,16 +1222,6 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
- if (allow_flex_ops_) {
- fbb->Vector([&]() {
- fbb->String(node_def.op());
- fbb->String(op.tensorflow_node_def);
- });
- fbb->Finish();
- LOG(INFO) << "Writing flex op: " << node_def.op();
- return std::unique_ptr<flexbuffers::Builder>(fbb.release());
- }
-
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 13d9f6c49a..6e4e0a16d1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/model.h"
@@ -36,6 +37,11 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
bool allow_flex_ops = false);
+// Write the custom option FlexBuffer with a serialized TensorFlow NodeDef
+// for a Flex op.
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def);
+
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
using BuiltinOptions = void;
diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD
index 3ba3ee29ec..2cf445a85e 100644
--- a/tensorflow/contrib/optimizer_v2/BUILD
+++ b/tensorflow/contrib/optimizer_v2/BUILD
@@ -47,15 +47,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
+ "//tensorflow/python:util",
+ "//tensorflow/python/keras:optimizer_v2",
],
)
diff --git a/tensorflow/contrib/optimizer_v2/adadelta.py b/tensorflow/contrib/optimizer_v2/adadelta.py
index b206f9f61b..9d73bddd1c 100644
--- a/tensorflow/contrib/optimizer_v2/adadelta.py
+++ b/tensorflow/contrib/optimizer_v2/adadelta.py
@@ -18,17 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.util import deprecation
-class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
+class AdadeltaOptimizer(adadelta.Adadelta):
"""Optimizer that implements the Adadelta algorithm.
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-8,
use_locking=False, name="Adadelta"):
"""Construct a new Adadelta optimizer.
@@ -48,66 +52,5 @@ class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
"""
- super(AdadeltaOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("rho", rho)
- self._set_hyper("epsilon", epsilon)
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "accum")
- state.zeros_slot(v, "accum_update")
-
- def _apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.sparse_apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_sparse_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdadeltaOptimizer, self).__init__(
+ learning_rate=learning_rate, rho=rho, epsilon=epsilon, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index dab1e02716..716361e29c 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -18,15 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.util import deprecation
-class AdagradOptimizer(optimizer_v2.OptimizerV2):
+class AdagradOptimizer(adagrad.Adagrad):
"""Optimizer that implements the Adagrad algorithm.
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
@@ -34,6 +30,10 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
[intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, initial_accumulator_value=0.1,
use_locking=False, name="Adagrad"):
"""Construct a new Adagrad optimizer.
@@ -54,64 +54,7 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
Raises:
ValueError: If the `initial_accumulator_value` is invalid.
"""
- if initial_accumulator_value <= 0.0:
- raise ValueError("initial_accumulator_value must be positive: %s" %
- initial_accumulator_value)
- super(AdagradOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- self._initial_accumulator_value = initial_accumulator_value
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- dtype = v.dtype.base_dtype
- if v.get_shape().is_fully_defined():
- init = init_ops.constant_initializer(self._initial_accumulator_value,
- dtype=dtype)
- else:
- def init(v=v, dtype=dtype):
- # Use a Tensor instead of initializer if variable does not have
- # static shape.
- init_constant = gen_array_ops.fill(array_ops.shape(v),
- self._initial_accumulator_value)
- return math_ops.cast(init_constant, dtype)
- state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
- "accumulator")
-
- def _apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.sparse_apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_sparse_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdagradOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ initial_accumulator_value=initial_accumulator_value,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad_test.py b/tensorflow/contrib/optimizer_v2/adagrad_test.py
index debaaaeeba..320e41567f 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad_test.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad_test.py
@@ -68,9 +68,6 @@ class AdagradOptimizerTest(test.TestCase):
def testBasicResource(self):
self.doTestBasic(use_locking=False, use_resource=True)
- def testBasicLocked(self):
- self.doTestBasic(use_locking=True)
-
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index 04b1552b61..363e020757 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -18,22 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.util import deprecation
-class AdamOptimizer(optimizer_v2.OptimizerV2):
+class AdamOptimizer(adam.Adam):
"""Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam"):
"""Construct a new Adam optimizer.
@@ -87,111 +86,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
"""
- super(AdamOptimizer, self).__init__(use_locking, name)
-
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("beta1", beta1)
- self._set_hyper("beta2", beta2)
- self._set_hyper("epsilon", epsilon)
-
- def _get_beta_accumulators(self, state=None):
- if state is None:
- state = self._get_per_graph_state()
- return (state.get_non_slot("beta1_power"),
- state.get_non_slot("beta2_power"))
-
- def _create_vars(self, var_list, state):
- # Non-slot variables end up on the same device(s).
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"),
- name="beta1_power")
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"),
- name="beta2_power")
-
- # Create slots for the first and second moments.
- for v in var_list:
- state.zeros_slot(v, "m")
- state.zeros_slot(v, "v")
-
- def _apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.apply_adam(
- var, m, v,
- math_ops.cast(beta1_power, var.dtype.base_dtype),
- math_ops.cast(beta2_power, var.dtype.base_dtype),
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("beta1", var.dtype.base_dtype),
- state.get_hyper("beta2", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad, use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.resource_apply_adam(
- var.handle, m.handle, v.handle,
- math_ops.cast(beta1_power, grad.dtype.base_dtype),
- math_ops.cast(beta2_power, grad.dtype.base_dtype),
- state.get_hyper("learning_rate", grad.dtype.base_dtype),
- state.get_hyper("beta1", grad.dtype.base_dtype),
- state.get_hyper("beta2", grad.dtype.base_dtype),
- state.get_hyper("epsilon", grad.dtype.base_dtype),
- grad, use_locking=self._use_locking)
-
- def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
- lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
- beta1_t = state.get_hyper("beta1", var.dtype.base_dtype)
- beta2_t = state.get_hyper("beta2", var.dtype.base_dtype)
- epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
- lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
- # m_t = beta1 * m + (1 - beta1) * g_t
- m = state.get_slot(var, "m")
- m_scaled_g_values = grad * (1 - beta1_t)
- m_t = state_ops.assign(m, m * beta1_t,
- use_locking=self._use_locking)
- with ops.control_dependencies([m_t]):
- m_t = scatter_add(m, indices, m_scaled_g_values)
- # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
- v = state.get_slot(var, "v")
- v_scaled_g_values = (grad * grad) * (1 - beta2_t)
- v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
- with ops.control_dependencies([v_t]):
- v_t = scatter_add(v, indices, v_scaled_g_values)
- v_sqrt = math_ops.sqrt(v_t)
- var_update = state_ops.assign_sub(var,
- lr * m_t / (v_sqrt + epsilon_t),
- use_locking=self._use_locking)
- return control_flow_ops.group(*[var_update, m_t, v_t])
-
- def _apply_sparse(self, grad, var, state):
- return self._apply_sparse_shared(
- grad.values, var, grad.indices,
- lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
- x, i, v, use_locking=self._use_locking),
- state)
-
- def _resource_scatter_add(self, x, i, v):
- with ops.control_dependencies(
- [resource_variable_ops.resource_scatter_add(
- x.handle, i, v)]):
- return x.value()
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- return self._apply_sparse_shared(
- grad, var, indices, self._resource_scatter_add, state)
-
- def _finish(self, state):
- # Update the power accumulators.
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- update_beta1 = beta1_power.assign(
- beta1_power * state.get_hyper("beta1"),
- use_locking=self._use_locking)
- update_beta2 = beta2_power.assign(
- beta2_power * state.get_hyper("beta2"),
- use_locking=self._use_locking)
- return control_flow_ops.group(update_beta1, update_beta2)
+ super(AdamOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ beta_1=beta1,
+ beta_2=beta2,
+ epsilon=epsilon,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index e13b82d1d2..3c68ef995a 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -130,8 +130,8 @@ class CheckpointingTests(test.TestCase):
# non-Layer dependency of the model
"model/_non_layer/a_variable",
# The optimizer creates two non-slot variables
- "optimizer/beta1_power",
- "optimizer/beta2_power",
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
# Slot variables
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
@@ -161,21 +161,20 @@ class CheckpointingTests(test.TestCase):
"my_model/dense/kernel",
named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power",
- named_variables["optimizer/beta1_power" + suffix].full_name)
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power",
- named_variables["optimizer/beta2_power" + suffix].full_name)
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
1].node_id]
- self.assertEqual("beta1_power",
- optimizer_node.children[0].local_name)
- self.assertEqual("beta1_power",
- serialized_graph.nodes[optimizer_node.children[0].node_id]
- .attributes[0].full_name)
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
self.assertEqual(
"my_model/dense/kernel",
serialized_graph.nodes[optimizer_node.slot_variables[0]
@@ -241,9 +240,10 @@ class CheckpointingTests(test.TestCase):
on_create_model = MyModel()
on_create_optimizer = adam.AdamOptimizer(
0.001,
- # Preserve beta1_power and beta2_power when appying gradients so we can
- # test that they've been restored correctly.
- beta1=1.0, beta2=1.0)
+ # Preserve beta_1_power and beta_2_power when appying gradients
+ # so we can test that they've been restored correctly.
+ beta1=1.0,
+ beta2=1.0)
on_create_root = util.Checkpoint(
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
@@ -263,9 +263,9 @@ class CheckpointingTests(test.TestCase):
dummy_var = resource_variable_ops.ResourceVariable([1.])
on_create_optimizer.minimize(loss=dummy_var.read_value)
status.assert_consumed()
- beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
- self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
- self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
# TODO(allenl): Debug garbage created by this test in python3.
def testDeferredRestorationUsageEager(self):
@@ -477,7 +477,7 @@ class CheckpointingTests(test.TestCase):
no_slot_status.run_restore_ops()
self.assertEqual(12., self.evaluate(new_root.var))
new_root.optimizer = adam.AdamOptimizer(0.1)
- with self.assertRaisesRegexp(AssertionError, "beta1_power"):
+ with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
slot_status.assert_consumed()
self.assertEqual(12., self.evaluate(new_root.var))
if context.executing_eagerly():
@@ -556,8 +556,8 @@ class CheckpointingTests(test.TestCase):
self.evaluate(first_variable.assign([1.]))
self.evaluate(optimizer.get_slot(
var=first_variable, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
# Save and load in a second graph
second_graph = ops.Graph()
@@ -571,29 +571,29 @@ class CheckpointingTests(test.TestCase):
self.evaluate(second_variable.assign([4.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([5.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(6.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(6.))
save_path = second_root_checkpointable.save(checkpoint_prefix)
self.evaluate(second_variable.assign([7.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([8.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
status = second_root_checkpointable.restore(save_path)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([4.], self.evaluate(second_variable))
self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
var=second_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
# Check that the first graph is unmolested
with first_graph.as_default(), first_session.as_default():
self.assertAllEqual([1.], self.evaluate(first_variable))
self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
var=first_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
class TemplateTests(test.TestCase):
@@ -659,8 +659,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.evaluate(model._named_dense.bias.assign([1.]))
self.evaluate(optimizer.get_slot(
var=model._named_dense.bias, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
return root_checkpointable
def _set_sentinels(self, root_checkpointable):
@@ -669,8 +669,8 @@ class CheckpointCompatibilityTests(test.TestCase):
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")
.assign([102.]))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(103.))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(103.))
def _check_sentinels(self, root_checkpointable):
self.assertAllEqual(
@@ -678,8 +678,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.assertAllEqual([2.], self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
def _write_name_based_checkpoint(self):
checkpoint_directory = self.get_temp_dir()
diff --git a/tensorflow/contrib/optimizer_v2/gradient_descent.py b/tensorflow/contrib/optimizer_v2/gradient_descent.py
index 945c8de559..8bdf408217 100644
--- a/tensorflow/contrib/optimizer_v2/gradient_descent.py
+++ b/tensorflow/contrib/optimizer_v2/gradient_descent.py
@@ -18,15 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
+class GradientDescentOptimizer(sgd.SGD):
"""Optimizer that implements the gradient descent algorithm."""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
"""Construct a new gradient descent optimizer.
@@ -41,29 +43,5 @@ class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "GradientDescent".
"""
- super(GradientDescentOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- def _apply_dense(self, grad, var, state):
- return training_ops.apply_gradient_descent(
- var,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, handle, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return training_ops.resource_apply_gradient_descent(
- handle.handle, lr, grad, use_locking=self._use_locking)
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return resource_variable_ops.resource_scatter_add(
- handle.handle, indices, -grad * lr)
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- delta = ops.IndexedSlices(
- grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.indices, grad.dense_shape)
- return var.scatter_sub(delta, use_locking=self._use_locking)
+ super(GradientDescentOptimizer, self).__init__(
+ learning_rate=learning_rate, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/momentum.py b/tensorflow/contrib/optimizer_v2/momentum.py
index 0a5aadc2d1..0636f7e356 100644
--- a/tensorflow/contrib/optimizer_v2/momentum.py
+++ b/tensorflow/contrib/optimizer_v2/momentum.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class MomentumOptimizer(optimizer_v2.OptimizerV2):
+class MomentumOptimizer(sgd.SGD):
"""Optimizer that implements the Momentum algorithm.
Computes (if `use_nesterov = False`):
@@ -39,6 +39,10 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
when that part of the variable was used in the forward pass.
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, momentum,
use_locking=False, name="Momentum", use_nesterov=False):
"""Construct a new Momentum optimizer.
@@ -68,57 +72,8 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
optimizer functions.
@end_compatibility
"""
- super(MomentumOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("momentum", momentum)
- self._use_nesterov = use_nesterov
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
-
- def _apply_sparse(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.sparse_apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_sparse_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
+ super(MomentumOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ momentum=momentum,
+ name=name,
+ nesterov=use_nesterov)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 53e27c08c4..9c98dd93b4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -20,462 +20,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import abc
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.util import deprecation
-from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import distribution_strategy_context
-from tensorflow.python.training import optimizer as optimizer_v1
-from tensorflow.python.training import slot_creator
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.util import nest
-
-class _OptimizableVariable(object):
- """Interface for abstracting over variables in the optimizers."""
-
- @abc.abstractmethod
- def target(self):
- """Returns the optimization target for this variable."""
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def update_op(self, optimizer, g, *args):
- """Returns the update ops for updating the variable."""
- raise NotImplementedError("Calling an abstract method.")
-
-
-class _RefVariableProcessor(_OptimizableVariable):
- """Processor for Variable."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v._ref() # pylint: disable=protected-access
-
- def update_op(self, optimizer, g, *args):
- if isinstance(g, ops.Tensor):
- update_op = optimizer._apply_dense(g, self._v, *args) # pylint: disable=protected-access
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
- else:
- assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
- "tensor nor IndexedSlices.")
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- # pylint: disable=protected-access
- return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)
-
-
-class _DenseReadResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _DenseResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- if isinstance(g, ops.IndexedSlices):
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- return optimizer._resource_apply_sparse_duplicate_indices(
- g.values, self._v, g.indices, *args)
- update_op = optimizer._resource_apply_dense(g, self._v, *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _TensorProcessor(_OptimizableVariable):
- """Processor for ordinary Tensors.
-
- Even though a Tensor can't really be updated, sometimes it is useful to
- compute the gradients with respect to a Tensor using the optimizer. Updating
- the Tensor is, of course, unsupported.
- """
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- raise NotImplementedError("Trying to update a Tensor ", self._v)
-
-
-def _get_processor(v):
- """The processor of v."""
- if context.executing_eagerly():
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- else:
- return _DenseResourceVariableProcessor(v)
- if v.op.type == "VarHandleOp":
- return _DenseResourceVariableProcessor(v)
- if isinstance(v, variables.Variable):
- return _RefVariableProcessor(v)
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- raise NotImplementedError("Trying to optimize unsupported type ", v)
-
-
-def _var_key_v2(var):
- """Key for representing a primary variable, for looking up slots."""
- # pylint: disable=protected-access
- if hasattr(var, "_distributed_container"):
- distributed_container = var._distributed_container()
- assert distributed_container is not None
- if context.executing_eagerly():
- return distributed_container._unique_id
- return distributed_container._shared_name
- if context.executing_eagerly():
- return var._unique_id
- return var.op.name
-
-
-def _resolve(value, name):
- if callable(value):
- value = value()
- return ops.convert_to_tensor(value, name=name)
-
-
-def _is_dynamic(value):
- """Returns true if __init__ arg `value` should be re-evaluated each step."""
- if callable(value): return True
- # Don't need to do anything special in graph mode, since dynamic values
- # will propagate correctly automatically.
- # TODO(josh11b): Add per-device caching across steps using variables for
- # truly static values once we add distributed support.
- if context.executing_eagerly() and isinstance(
- value, resource_variable_ops.ResourceVariable):
- return True
- return False
-
-
-class _OptimizerV2State(object):
- """Holds per-graph and per-step optimizer state.
-
- Use _init_with_static_hyper() to create the state for a graph, and then
- _copy_with_dynamic_hyper() to convert that to state for a particular step.
- The difference between the two is that the former only has hyper
- parameter values that are static and the latter also has values that
- can change every step (according to _is_dynamic()).
- """
-
- def __init__(self, op_name):
- self._op_name = op_name
-
- def _init_with_static_hyper(self, hyper):
- """Initialize a fresh state object from hyper dict."""
- # self._hyper contains a dict from name to a dict with the Tensor values.
- # This dict starts with a single item with key "None" with the hyper
- # parameter value converted to a Tensor. Other items have dtype keys
- # with that Tensor cast to that dtype.
- with ops.init_scope():
- self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if not dynamic}
- self._slots = {}
- self._non_slot_dict = {}
- # Extra state to help Optimizers implement Checkpointable. Holds information
- # about variables which will be restored as soon as they're created.
- self._deferred_dependencies = {} # Non-slot variables
- self._deferred_slot_restorations = {} # Slot variables
-
- def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
- """Create a new state object for a particular step."""
- ret = _OptimizerV2State(self._op_name)
- # pylint: disable=protected-access
- ret._slots = self._slots
- ret._non_slot_dict = self._non_slot_dict
- ret._deferred_dependencies = self._deferred_dependencies
- ret._deferred_slot_restorations = self._deferred_slot_restorations
- ret._hyper = {name: {None: _resolve(value, name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if dynamic}
- ret._hyper.update(self._hyper)
- ret._non_slot_devices = non_slot_devices
- ret._distribution = distribution
- return ret
-
- def _variables(self):
- """Returns a list of all variables held by self."""
- optimizer_variables = list(self._non_slot_dict.values())
- for variable_dict in self._slots.values():
- for slot_for_variable in variable_dict.values():
- optimizer_variables.append(slot_for_variable)
- # Sort variables by name so that the return is deterministic.
- return sorted(optimizer_variables, key=lambda v: v.name)
-
- def _slot_dict(self, slot_name):
- """Returns a dict for caching slots created under the given name.
-
- Args:
- slot_name: Name for the slot.
-
- Returns:
- A dict that maps primary `Variable` objects to the slot created
- for that variable, under the given slot name.
- """
- named_slots = self._slots.get(slot_name, None)
- if named_slots is None:
- named_slots = {}
- self._slots[slot_name] = named_slots
- return named_slots
-
- def create_slot(self, var, val, slot_name, optional_op_name=None):
- """Find or create a slot for a variable.
-
- Args:
- var: A `Variable` object.
- val: A `Tensor`. The initial value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot(
- var, val, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def create_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, optional_op_name=None):
- """Find or create a slot for a variable, using an Initializer.
-
- Args:
- var: A `Variable` object.
- initializer: An `Initializer`. The initial value of the slot.
- shape: Shape of the initial value of the slot.
- dtype: Type of the value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot_with_initializer(
- var, initializer, shape, dtype, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def zeros_slot(self, var, slot_name, optional_op_name=None):
- """Find or create a slot initialized with 0.0.
-
- Args:
- var: A `Variable` object.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_zeros_slot(
- var, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable,
- optional_op_name=None):
- """Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored. When executing eagerly, we create the slot variable with a
- restoring initializer.
-
- No new variables are created when graph building. Instead,
- _restore_slot_variable catches these after normal creation and adds restore
- ops to the graph. This method is nonetheless important when graph building
- for the case when a slot variable has already been created but `variable`
- has just been added to a dependency graph (causing us to realize that the
- slot variable needs to be restored).
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
- """
- slot_variable = self.get_slot(var=variable, name=slot_name)
- if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()
- # Defer slot variable creation if there is an active variable creator
- # scope. Generally we'd like to eagerly create/restore slot variables
- # when possible, but this may mean that scopes intended to catch
- # `variable` also catch its eagerly created slot variable
- # unintentionally (specifically make_template would add a dependency on
- # a slot variable if not for this case). Deferring is mostly harmless
- # (aside from double initialization), and makes variable creator scopes
- # behave the same way they do when graph building.
- and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
- initializer = checkpointable.CheckpointInitialValue(
- checkpoint_position=slot_variable_position)
- slot_variable = self.create_slot(
- var=variable,
- val=initializer,
- slot_name=slot_name,
- optional_op_name=optional_op_name)
- # Optimizers do not have unconditional dependencies on their slot
- # variables (nor do any other objects). They are only saved if the
- # variables they were created for are also saved.
- if slot_variable is not None:
- # If we've either made this slot variable, or if we've pulled out an
- # existing slot variable, we should restore it.
- slot_variable_position.restore(slot_variable)
- else:
- # We didn't make the slot variable. Defer restoring until it gets created
- # normally. We keep a list rather than the one with the highest restore
- # UID in case slot variables have their own dependencies, in which case
- # those could differ between restores.
- variable_key = _var_key_v2(variable)
- self._deferred_slot_restorations.setdefault(
- slot_name, {}).setdefault(variable_key, []).append(
- slot_variable_position)
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- named_slots = self._slots.get(name, None)
- if not named_slots:
- return None
- return named_slots.get(_var_key_v2(var), None)
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- return sorted(self._slots.keys())
-
- def create_non_slot(self, initial_value, name, colocate_with=None):
- """Add an extra variable, not associated with a slot."""
- v = self._non_slot_dict.get(name, None)
- if v is None:
- if colocate_with is None: colocate_with = self._non_slot_devices
- with self._distribution.colocate_vars_with(colocate_with):
- # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
- v = variable_scope.variable(initial_value, name=name, trainable=False)
- self._non_slot_dict[name] = v
- deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
- for checkpoint_position in sorted(
- deferred_dependencies_list,
- key=lambda restore: restore.checkpoint.restore_uid,
- reverse=True):
- checkpoint_position.restore(v)
- return v
-
- def _restore_slot_variable(self, slot_name, variable, slot_variable):
- """Restore a newly created slot variable's value."""
- variable_key = _var_key_v2(variable)
- deferred_restorations = self._deferred_slot_restorations.get(
- slot_name, {}).pop(variable_key, [])
- # Iterate over restores, highest restore UID first to minimize the number
- # of assignments.
- deferred_restorations.sort(key=lambda position: position.restore_uid,
- reverse=True)
- for checkpoint_position in deferred_restorations:
- checkpoint_position.restore(slot_variable)
-
- def get_non_slot(self, name):
- """Returns the non-slot variable identified by `name`."""
- return self._non_slot_dict.get(name, None)
-
- def get_hyper(self, name, dtype=None):
- """Returns the `name` hyper parameter, optionally cast to `dtype`."""
- dtype_dict = self._hyper[name]
- # Do we have the value cast to dtype already cached? This should always
- # succeed when dtype is None.
- if dtype in dtype_dict:
- return dtype_dict[dtype]
- # Not cached, cast to dtype and save the result in the cache.
- result = math_ops.cast(dtype_dict[None], dtype)
- dtype_dict[dtype] = result
- return result
-
-
-class OptimizerV2(optimizer_v1.Optimizer):
+class OptimizerV2(optimizer_v2.OptimizerV2):
"""Updated base class for optimizers.
This class defines the API to add Ops to train a model. You never use this
@@ -586,6 +135,10 @@ class OptimizerV2(optimizer_v1.Optimizer):
GATE_OP = 1
GATE_GRAPH = 2
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, use_locking, name):
"""Create a new Optimizer.
@@ -606,746 +159,4 @@ class OptimizerV2(optimizer_v1.Optimizer):
RuntimeError: If _create_slots has been overridden instead of
_create_vars.
"""
- # Note: We intentionally don't call parent __init__.
-
- # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
- if (self.__class__._create_slots.__code__ is not # pylint: disable=protected-access
- OptimizerV2._create_slots.__code__):
- raise RuntimeError("Override _create_vars instead of _create_slots when "
- "descending from OptimizerV2 (class %s)" %
- self.__class__.__name__)
- if not name:
- raise ValueError("Must specify the optimizer name")
-
- self._use_locking = use_locking
- self._name = name
- # Map from graph_key to state for that graph. We use the graph_key
- # since it works in both eager and graph mode, and gives the outer
- # graph inside functions.
- tower_context = distribution_strategy_context.get_tower_context()
- if tower_context is None:
- # In a cross-tower context for a DistributionStrategy, which means
- # only one Optimizer will be created, not one per tower.
- self._per_graph_state = {}
- else:
- # We use get_tower_context().merge_call() to get a single dict
- # shared across all model replicas when running with a
- # DistributionStrategy.
- self._per_graph_state = tower_context.merge_call(lambda _: {})
-
- # Hyper parameters, and whether they should be re-evaluated every step.
- self._hyper = {}
-
- def _set_hyper(self, name, value):
- self._hyper[name] = (_is_dynamic(value), value)
-
- def minimize(self, loss, global_step=None, var_list=None,
- gate_gradients=GATE_OP, aggregation_method=None,
- colocate_gradients_with_ops=False, name=None,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Add operations to minimize `loss` by updating `var_list`.
-
- This method simply combines calls `compute_gradients()` and
- `apply_gradients()`. If you want to process the gradient before applying
- them call `compute_gradients()` and `apply_gradients()` explicitly instead
- of using this function.
-
- Args:
- loss: A `Tensor` containing the value to minimize.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- var_list: Optional list or tuple of `Variable` objects to update to
- minimize `loss`. Defaults to the list of variables collected in
- the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- name: Optional name for the returned operation.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- An Operation that updates the variables in `var_list`. If `global_step`
- was not `None`, that operation also increments `global_step`.
-
- Raises:
- ValueError: If some of the variables are not `Variable` objects.
-
- @compatibility(eager)
- When eager execution is enabled, `loss` should be a Python function that
- takes elements of `var_list` as arguments and computes the value to be
- minimized. If `var_list` is None, `loss` should take no arguments.
- Minimization (and gradient computation) is done with respect to the
- elements of `var_list` if not None, else with respect to any trainable
- variables created during the execution of the `loss` function.
- `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
- `grad_loss` are ignored when eager execution is enabled.
- @end_compatibility
- """
- grads_and_vars = self.compute_gradients(
- loss, var_list=var_list, gate_gradients=gate_gradients,
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- grad_loss=grad_loss, stop_gradients=stop_gradients,
- scale_loss_by_num_towers=scale_loss_by_num_towers)
-
- vars_with_grad = [v for g, v in grads_and_vars if g is not None]
- if not vars_with_grad:
- raise ValueError(
- "No gradients provided for any variable, check your graph for ops"
- " that do not support gradients, between variables %s and loss %s." %
- ([str(v) for _, v in grads_and_vars], loss))
-
- return self.apply_gradients(grads_and_vars, global_step=global_step,
- name=name)
-
- def compute_gradients(self, loss, var_list=None,
- gate_gradients=GATE_OP,
- aggregation_method=None,
- colocate_gradients_with_ops=False,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Compute gradients of `loss` for the variables in `var_list`.
-
- This is the first part of `minimize()`. It returns a list
- of (gradient, variable) pairs where "gradient" is the gradient
- for "variable". Note that "gradient" can be a `Tensor`, an
- `IndexedSlices`, or `None` if there is no gradient for the
- given variable.
-
- Args:
- loss: A Tensor containing the value to minimize or a callable taking
- no arguments which returns the value to minimize. When eager execution
- is enabled it must be a callable.
- var_list: Optional list or tuple of `tf.Variable` to update to minimize
- `loss`. Defaults to the list of variables collected in the graph
- under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- A list of (gradient, variable) pairs. Variable is always present, but
- gradient can be `None`.
-
- Raises:
- TypeError: If `var_list` contains anything else than `Variable` objects.
- ValueError: If some arguments are invalid.
- RuntimeError: If called with eager execution enabled and `loss` is
- not callable.
-
- @compatibility(eager)
- When eager execution is enabled, `gate_gradients`, `aggregation_method`,
- and `colocate_gradients_with_ops` are ignored.
- @end_compatibility
- """
- # TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if callable(loss):
- with backprop.GradientTape() as tape:
- if var_list is not None:
- tape.watch(var_list)
- loss_value = loss()
-
- # Scale loss for number of towers (callable-loss case). In this case,
- # we have to be careful to call distribute_lib.get_loss_reduction()
- # *after* loss() is evaluated, so we know what loss reduction it uses.
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss_value *= 1. / num_towers
-
- if var_list is None:
- var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
- return list(zip(grads, var_list))
- if context.executing_eagerly():
- raise RuntimeError(
- "`loss` passed to Optimizer.compute_gradients should "
- "be a function when eager execution is enabled.")
-
- # Scale loss for number of towers (non-callable-loss case).
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss *= 1. / num_towers
-
- if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
- optimizer_v1.Optimizer.GATE_OP,
- optimizer_v1.Optimizer.GATE_GRAPH]:
- raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
- "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
- gate_gradients)
- self._assert_valid_dtypes([loss])
- if grad_loss is not None:
- self._assert_valid_dtypes([grad_loss])
- if var_list is None:
- var_list = (
- variables.trainable_variables() +
- ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
- else:
- var_list = nest.flatten(var_list)
- # pylint: disable=protected-access
- var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
- # pylint: enable=protected-access
- processors = [_get_processor(v) for v in var_list]
- if not var_list:
- raise ValueError("No variables to optimize.")
- var_refs = [p.target() for p in processors]
- grads = gradients.gradients(
- loss, var_refs, grad_ys=grad_loss,
- gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- stop_gradients=stop_gradients)
- if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
- grads = control_flow_ops.tuple(grads)
- grads_and_vars = list(zip(grads, var_list))
- self._assert_valid_dtypes(
- [v for g, v in grads_and_vars
- if g is not None and v.dtype != dtypes.resource])
- return grads_and_vars
-
- def apply_gradients(self, grads_and_vars, global_step=None, name=None):
- """Apply gradients to variables.
-
- This is the second part of `minimize()`. It returns an `Operation` that
- applies gradients.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs as returned by
- `compute_gradients()`.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- name: Optional name for the returned operation. Default to the
- name passed to the `Optimizer` constructor.
-
- Returns:
- An `Operation` that applies the specified gradients. If `global_step`
- was not None, that operation also increments `global_step`.
-
- Raises:
- TypeError: If `grads_and_vars` is malformed.
- ValueError: If none of the variables have gradients.
- """
- # This is a default implementation of apply_gradients() that can be shared
- # by most optimizers. It relies on the subclass implementing the following
- # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().
-
- # Filter out variables with gradients of `None`.
- grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
- if not grads_and_vars:
- raise ValueError("No variables provided.")
- filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
- if not filtered:
- raise ValueError("No gradients provided for any variable: %s." %
- ([str(v) for _, v in grads_and_vars],))
- return distribution_strategy_context.get_tower_context().merge_call(
- self._distributed_apply, filtered, global_step=global_step, name=name)
-
- def _get_or_create_state(self, var_list=None):
- """Either looks up or creates `_OptimizerV2State`.
-
- If any variables are available, they should be passed via the `var_list`
- argument, and these will be used to determine the graph to create/retrieve
- state for. Otherwise the returned state is for the current default graph.
-
- Args:
- var_list: A list of variables to extract a graph from.
-
- Returns:
- An `_OptimizerV2State` object.
- """
- # Determine the graph_key from the current graph.
- eager_execution = context.executing_eagerly()
- if eager_execution or var_list is None:
- graph = ops.get_default_graph()
- else:
- graph = ops._get_graph_from_inputs(var_list) # pylint: disable=protected-access
- assert graph is not None
- graph_key = graph._graph_key # pylint: disable=protected-access
-
- # Get the per graph state by looking up the graph_key.
- if graph_key in self._per_graph_state:
- per_graph_state = self._per_graph_state[graph_key]
- else:
- per_graph_state = _OptimizerV2State(self._name)
- per_graph_state._init_with_static_hyper(self._hyper) # pylint: disable=protected-access
- self._per_graph_state[graph_key] = per_graph_state
- return per_graph_state
-
- def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
- """`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce(
- variable_scope.VariableAggregation.SUM, grads_and_vars)
- var_list = [v for _, v in grads_and_vars]
- grads_and_vars = zip(reduced_grads, var_list)
-
- unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
- eager_execution = context.executing_eagerly()
- if eager_execution:
- # Give a clear error in this case instead of "name not supported
- # for Eager Tensors" when we compute non_slot_devices.
- for v in unwrapped_var_list:
- if isinstance(v, ops.Tensor):
- raise NotImplementedError("Trying to update a Tensor ", v)
-
- with ops.name_scope(name, self._name) as name:
- per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
- # Include the current value of any dynamic hyper parameters in `state`.
- non_slot_devices = distribution.non_slot_devices(var_list)
- state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
- self._hyper, distribution, non_slot_devices)
-
- # Create any slot and non-slot variables we need in `state`.
- with ops.init_scope():
- self._create_vars(var_list, state)
-
- with ops.name_scope(name): # Re-enter name_scope created above
- # Give the child class a chance to do something before we start
- # applying gradients.
- self._prepare(state)
-
- def update(v, g):
- """Update variable `v` using gradient `g`."""
- assert v is not None
-
- # Convert the grad to Tensor or IndexedSlices if necessary, and
- # look up a processor for each variable's type.
- try:
- g = ops.convert_to_tensor_or_indexed_slices(g)
- except TypeError:
- raise TypeError(
- "Gradient must be convertible to a Tensor"
- " or IndexedSlices, or None: %s" % g)
- if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
- raise TypeError(
- "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
- processor = _get_processor(v)
-
- # We colocate all ops created in _apply_dense or _apply_sparse
- # on the same device as the variable.
- # TODO(apassos): figure out how to get the variable name here.
- scope_name = "" if eager_execution else v.op.name
- # device_policy is set because non-mirrored tensors will be read in
- # `update_op`.
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with ops.name_scope("update_" + scope_name), \
- context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return processor.update_op(self, g, state)
-
- # Use the processors to update the variables.
- update_ops = []
- for grad, var in grads_and_vars:
- update_ops.extend(distribution.update(var, update, grad, grouped=False))
-
- # Give the child class a chance to do something after applying
- # gradients
- def finish():
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return self._finish(state)
-
- update_ops = control_flow_ops.group(update_ops)
- with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, grouped=False)
- # We said grouped=False, which means finish_updates is always a list.
- # It will be [None] when finish() returns None.
- if finish_updates == [None]:
- finish_updates = [update_ops]
-
- # Update `global_step` (if any).
- if global_step is None:
- apply_updates = distribution.group(finish_updates, name=name)
- else:
- with ops.control_dependencies(finish_updates):
-
- def update_global_step(global_step, name):
- return global_step.assign_add(1, read_value=False, name=name)
-
- apply_updates = distribution.update(
- global_step, update_global_step, name)
-
- # Add the training op to the TRAIN_OP graph collection in graph mode.
- if not eager_execution:
- if isinstance(apply_updates, ops.Tensor):
- apply_updates = apply_updates.op
- train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
- if apply_updates not in train_op:
- train_op.append(apply_updates)
-
- return apply_updates
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- state = self._get_state_for_var(var)
- return state.get_slot(var, name) if state is not None else None
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- state = self._get_per_graph_state()
- return state.get_slot_names() if state is not None else []
-
- def variables(self):
- """A list of variables which encode the current state of `Optimizer`.
-
- Includes slot variables and additional global variables created by the
- optimizer in the current default graph.
-
- Returns:
- A list of variables.
- """
- state = self._get_per_graph_state()
- return state._variables() if state is not None else [] # pylint: disable=protected-access
-
- # --------------
- # Methods to be implemented by subclasses if they want to use the
- # inherited implementation of apply_gradients() or compute_gradients().
- # --------------
- def _create_vars(self, var_list, state):
- """Create all slots needed by the variables and any non-slot variables.
-
- Args:
- var_list: A list of `Variable` objects.
- state: An object with these methods:
- `create_slot(var, val, slot_name, optional_op_name)`,
- `create_slot_with_initializer(`
- `var, initializer, shape, dtype, slot_name, optional_op_name)`,
- `zeros_slot(var, slot_name, optional_op_name)`,
- `create_non_slot_variable(initial_value, name, colocate_with)`,
- `get_hyper(name)`
- """
- # No slots needed by default
- pass
-
- def _prepare(self, state):
- """Code to execute before applying gradients.
-
- Note that most uses of _prepare() in Optimizer have been subsumed
- by explicit support for hyper parameters in OptimizerV2
-
- Args:
- state: An object with a `get_hyper(name)` method.
-
- Returns:
- Return value will be ignored.
- """
- pass
-
- def _apply_dense(self, grad, var, state):
- """Add ops to apply dense gradients to `var`.
-
- Args:
- grad: A `Tensor`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _resource_apply_dense(self, grad, handle, state):
- """Add ops to apply dense gradients to the variable `handle`.
-
- Args:
- grad: a `Tensor` representing the gradient.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to `handle`, with repeated indices.
-
- Optimizers which override this method must deal with repeated indices. See
- the docstring of `_apply_sparse_duplicate_indices` for details. By default
- the correct behavior, to sum non-unique indices and their associated
- gradients, is enforced by first pre-processing `grad` and `indices` and
- passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
- with duplicate indices may instead override this method to avoid the
- overhead of summing.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices may be repeated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- # pylint: disable=protected-access
- summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad, indices=indices)
- # pylint: enable=protected-access
- return self._resource_apply_sparse(
- summed_grad, handle, unique_indices, state)
-
- def _resource_apply_sparse(self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to the variable `handle`.
-
- Similar to `_apply_sparse`, the `indices` argument to this method has been
- de-duplicated. Optimizers which deal correctly with non-unique indices may
- instead override `_resource_apply_sparse_duplicate_indices` to avoid this
- overhead.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices are unique.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
-
- Optimizers which override this method must deal with IndexedSlices objects
- such as the following:
-
- IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
-
- The correct interpretation is:
-
- IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
-
- Many optimizers deal incorrectly with repeated indices when updating based
- on sparse gradients (e.g. summing squares rather than squaring the sum, or
- applying momentum terms multiple times). Adding first is always the correct
- behavior, so this is enforced here by reconstructing the IndexedSlices to
- have only unique indices, then calling _apply_sparse.
-
- Optimizers which deal correctly with repeated indices may instead override
- this method to avoid the overhead of summing indices.
-
- Args:
- grad: `IndexedSlices`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- # pylint: disable=protected-access
- summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad.values, indices=grad.indices)
- # pylint: enable=protected-access
- gradient_no_duplicate_indices = ops.IndexedSlices(
- indices=unique_indices,
- values=summed_values,
- dense_shape=grad.dense_shape)
- return self._apply_sparse(gradient_no_duplicate_indices, var, state)
-
- def _apply_sparse(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`.
-
- The IndexedSlices object passed to `grad` in this function is by default
- pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
- indices (see its docstring for details). Optimizers which can tolerate or
- have correct special cases for duplicate sparse indices may override
- `_apply_sparse_duplicate_indices` instead of this function, avoiding that
- overhead.
-
- Args:
- grad: `IndexedSlices`, with no repeated indices.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _finish(self, state):
- """Do what is needed to finish the update.
-
- This is called inside a scope colocated with any non-slot variables.
-
- Args:
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- The operation to apply updates, or None if no updates.
- """
- return None
-
- # --------------
- # Utility methods for subclasses.
- # --------------
- def _get_per_graph_state(self):
- # pylint: disable=protected-access
- return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)
-
- def _get_state_for_var(self, var):
- # pylint: disable=protected-access
- return self._per_graph_state.get(var._graph_key, None)
-
- # --------------
- # Overridden methods from Checkpointable.
- # --------------
-
- def _track_checkpointable(self, *args, **kwargs):
- """Optimizers may not track dependencies. Raises an error."""
- raise NotImplementedError(
- "Optimizers may not have dependencies. File a feature request if this "
- "limitation bothers you.")
-
- @property
- def _checkpoint_dependencies(self):
- """From Checkpointable. Gather graph-specific non-slot variables to save."""
- current_graph_non_slot_variables = []
- state = self._get_per_graph_state()
- if state is not None:
- for name, variable_object in sorted(
- state._non_slot_dict.items(), # pylint: disable=protected-access
- # Avoid comparing variables
- key=lambda item: item[0]):
- current_graph_non_slot_variables.append(
- checkpointable.CheckpointableReference(
- name=name, ref=variable_object))
- # Note: ignores super(); Optimizers may not have any dependencies outside of
- # state objects.
- return current_graph_non_slot_variables
-
- def _lookup_dependency(self, name):
- """From Checkpointable. Find a non-slot variable in the current graph."""
- state = self._get_per_graph_state()
- if state is None:
- return None
- else:
- return state.get_non_slot(name)
-
- @property
- def _deferred_dependencies(self):
- """Lets Checkpointable know where non-slot variables are created.
-
- If necessary, creates a new state object for the current default graph.
- Checkpointable will then add entries to that state's deferred dependency
- dictionary. The state object will check that dictionary when creating
- non-slot variables, restoring their value if an entry is found.
-
- Returns:
- A dictionary which holds deferred dependencies for the current default
- graph.
- """
- state = self._get_or_create_state()
- return state._deferred_dependencies # pylint: disable=protected-access
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable):
- """Checkpointable: Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored.
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- """
- state = self._get_or_create_state(var_list=[variable])
- state._create_or_restore_slot_variable( # pylint: disable=protected-access
- slot_variable_position=slot_variable_position,
- slot_name=slot_name,
- variable=variable,
- optional_op_name=self._name)
-
- # --------------
- # Unsupported parent methods
- # --------------
- def _slot_dict(self, slot_name):
- raise NotImplementedError(
- "_slot_dict() method unsupported in OptimizerV2")
-
- def _get_or_make_slot(self, var, val, slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot() method unsupported in OptimizerV2")
-
- def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot_with_initializer() method unsupported in "
- "OptimizerV2")
-
- def _create_non_slot_variable(self, initial_value, name, colocate_with):
- raise NotImplementedError(
- "_create_non_slot_variable() method unsupported in OptimizerV2")
-
- def _get_non_slot_variable(self, name, graph=None):
- raise NotImplementedError(
- "_get_non_slot_variable() method unsupported in OptimizerV2")
-
- def _non_slot_variables(self):
- raise NotImplementedError(
- "_non_slot_variables() method unsupported in OptimizerV2")
+ super(OptimizerV2, self).__init__(name)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py
index 3de53405ec..090e257ddc 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop.py
@@ -41,19 +41,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.util import deprecation
-from tensorflow.python.training import training_ops
-
-class RMSPropOptimizer(optimizer_v2.OptimizerV2):
+class RMSPropOptimizer(rmsprop.RMSProp):
"""Optimizer that implements the RMSProp algorithm.
See the
[paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self,
learning_rate,
decay=0.9,
@@ -96,138 +98,10 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "RMSProp".
"""
- super(RMSPropOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("decay", decay)
- self._set_hyper("momentum", momentum)
- self._set_hyper("epsilon", epsilon)
-
- self._centered = centered
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- init_rms = state.get_hyper(
- "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
- state.create_slot_with_initializer(v, init_rms, v.get_shape(),
- v.dtype.base_dtype, "rms")
- if self._centered:
- state.zeros_slot(v, "mg")
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- # epsilon is now the rms initial value and is not added to the
- # denominator anymore, hence calling the kernel op with epsilon=0.
- 0,
- grad,
- use_locking=self._use_locking).op
- else:
- return training_ops.apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.resource_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.sparse_apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
- else:
- return training_ops.sparse_apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = self.get_slot(var, "mg")
- return training_ops.resource_sparse_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_sparse_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
+ super(RMSPropOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ rho=decay,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
index 44301ffe9e..83f5971039 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
@@ -157,8 +157,11 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ # TODO(b/117393988): Reduce tolerances for float16.
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=3e-3, half_atol=3e-3)
+ self.assertAllCloseAccordingToType(
+ var1_np, var1.eval(), half_rtol=3e-3, half_atol=3e-3)
@parameterized.parameters([dtypes.float32, dtypes.float64])
def testMinimizeSparseResourceVariable(self, dtype):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 78cea8feb4..0f693e9154 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1110,7 +1110,7 @@ _Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
class AttentionCellWrapper(rnn_cell_impl.RNNCell):
"""Basic attention cell wrapper.
- Implementation based on https://arxiv.org/abs/1409.0473.
+ Implementation based on https://arxiv.org/abs/1601.06733.
"""
def __init__(self,
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
index 360e7dbe75..7743f5b4a7 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -109,6 +109,42 @@ class SparsemaxLossTest(test.TestCase):
np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_loss, tf_loss_op)
+ def _test_sparsemax_loss_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax-loss transfers nan"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_nan = np.asarray([[0, np.nan, 0], [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]]).astype(dtype)
+
+ _, tf_loss_nan = self._tf_sparsemax_loss(z_nan, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan], tf_loss_nan)
+
+ def _test_sparsemax_loss_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax-loss is infinity safe"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, 0],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf,
+ np.inf], [np.inf, np.inf, 0],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, 0], [-np.inf, np.inf,
+ -np.inf]]).astype(dtype)
+
+ _, tf_loss_neg = self._tf_sparsemax_loss(z_neg, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([0.25, np.inf, 0, np.nan], tf_loss_neg)
+
+ _, tf_loss_pos = self._tf_sparsemax_loss(z_pos, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_pos)
+
+ _, tf_loss_mix = self._tf_sparsemax_loss(z_mix, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_mix)
+
def _test_constant_add(self, dtype, random, use_gpu):
"""check sparsemax-loss proposition 3"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -198,6 +234,10 @@ class SparsemaxLossTest(test.TestCase):
self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)
+ self._test_sparsemax_loss_of_nan(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_of_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
index 259e62bd86..c95b9da1e4 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -87,6 +87,46 @@ class SparsemaxTest(test.TestCase):
p_sparemax, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
+ def _test_sparsemax_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax transfers nan"""
+ z_nan = np.asarray([
+ [0, np.nan, 0],
+ [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan],
+ ]).astype(dtype)
+
+ _, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_nan)
+
+ def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax is infinity safe"""
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, -np.inf]]).astype(dtype)
+
+ _, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg)
+
+ _, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_pos)
+
+ _, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_mix)
+
+
def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 1"""
z = np.zeros((1, 10))
@@ -97,7 +137,7 @@ class SparsemaxTest(test.TestCase):
self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
- def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ def _test_sparsemax_of_to_inf(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 2"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -210,10 +250,14 @@ class SparsemaxTest(test.TestCase):
self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)
- self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_nan(dtype, random, use_gpu=False)
self._test_sparsemax_of_inf(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_to_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_permutation(dtype, random, use_gpu=False)
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
index e617af2ff1..f79c93f347 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -49,7 +49,14 @@ def sparsemax(logits, name=None):
obs = array_ops.shape(logits)[0]
dims = array_ops.shape(logits)[1]
- z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # The mean(logits) can be substracted from logits to make the algorithm
+ # more numerically stable. the instability in this algorithm comes mostly
+ # from the z_cumsum. Substacting the mean will cause z_cumsum to be close
+ # to zero. However, in practise the numerical instability issues are very
+ # minor and substacting the mean causes extra issues with inf and nan
+ # input.
+ z = logits
# sort z
z_sorted, _ = nn.top_k(z, k=dims)
@@ -64,10 +71,24 @@ def sparsemax(logits, name=None):
k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)
# calculate tau(z)
- indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
+ # If there are inf values or all values are -inf, the k_z will be zero,
+ # this is mathematically invalid and will also cause the gather_nd to fail.
+ # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then
+ # fixed later (see p_safe) by returning p = nan. This results in the same
+ # behavior as softmax.
+ k_z_safe = math_ops.maximum(k_z, 1)
+ indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1)
tau_sum = array_ops.gather_nd(z_cumsum, indices)
tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)
# calculate p
- return math_ops.maximum(
+ p = math_ops.maximum(
math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis])
+ # If k_z = 0 or if z = nan, then the input is invalid
+ p_safe = array_ops.where(
+ math_ops.logical_or(
+ math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])),
+ array_ops.fill([obs, dims], math_ops.cast(float("nan"), logits.dtype)),
+ p)
+
+ return p_safe
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
index 582d1e6136..c0438f16bc 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -47,14 +47,30 @@ def sparsemax_loss(logits, sparsemax, labels, name=None):
sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
labels = ops.convert_to_tensor(labels, name="labels")
- shifted_logits = logits - \
- math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # A constant can be substracted from logits to make the algorithm
+ # more numerically stable in theory. However, there are really no major
+ # source numerical instability in this algorithm.
+ z = logits
# sum over support
- support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
- sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)
+ # Use a conditional where instead of a multiplication to support z = -inf.
+ # If z = -inf, and there is no support (sparsemax = 0), a multiplication
+ # would cause 0 * -inf = nan, which is not correct in this case.
+ sum_s = array_ops.where(
+ math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)),
+ sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax))
# - z_k + ||q||^2
- q_part = labels * (0.5 * labels - shifted_logits)
+ q_part = labels * (0.5 * labels - z)
+ # Fix the case where labels = 0 and z = -inf, where q_part would
+ # otherwise be 0 * -inf = nan. But since the lables = 0, no cost for
+ # z = -inf should be consideredself.
+ # The code below also coveres the case where z = inf. Howeverm in this
+ # caose the sparsemax will be nan, which means the sum_s will also be nan,
+ # therefor this case doesn't need addtional special treatment.
+ q_part_safe = array_ops.where(
+ math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)),
+ array_ops.zeros_like(z), q_part)
- return math_ops.reduce_sum(sum_s + q_part, axis=1)
+ return math_ops.reduce_sum(sum_s + q_part_safe, axis=1)
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
index a217397c1a..e9ddec8889 100644
--- a/tensorflow/contrib/stateless/BUILD
+++ b/tensorflow/contrib/stateless/BUILD
@@ -11,7 +11,10 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
py_library(
name = "stateless",
- srcs = ["__init__.py"],
+ srcs = [
+ "__init__.py",
+ "python/stateless_ops.py",
+ ],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
index fe23fe0dd8..30d0a7ab6a 100644
--- a/tensorflow/contrib/stateless/__init__.py
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -32,16 +32,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
-
# pylint: disable=wildcard-import
-from tensorflow.python.ops.gen_stateless_random_ops import *
+from tensorflow.contrib.stateless.python.stateless_ops import *
from tensorflow.python.util.all_util import remove_undocumented
-ops.NotDifferentiable("StatelessMultinomial")
-ops.NotDifferentiable("StatelessRandomNormal")
-ops.NotDifferentiable("StatelessRandomUniform")
-ops.NotDifferentiable("StatelessTruncatedNormal")
-
remove_undocumented(__name__)
diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
index d724a5c014..ec5a13b7c6 100644
--- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
+++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
import numpy as np
from tensorflow.contrib import stateless
from tensorflow.python.framework import constant_op
@@ -27,10 +29,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-CASES = [(stateless.stateless_random_uniform, random_ops.random_uniform),
- (stateless.stateless_random_normal, random_ops.random_normal),
- (stateless.stateless_truncated_normal, random_ops.truncated_normal)]
-
def invert_philox(key, value):
"""Invert the Philox bijection."""
@@ -51,90 +49,30 @@ def invert_philox(key, value):
class StatelessOpsTest(test.TestCase):
- def testMatchStateful(self):
+ def _test_match(self, cases):
# Stateless ops should be the same as stateful ops on the first call
# after seed scrambling.
+ cases = tuple(cases)
key = 0x3ec8f720, 0x02461e29
for seed in (7, 17), (11, 5), (2, 3):
preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
preseed = preseed[::2] | preseed[1::2] << 32
random_seed.set_random_seed(seed[0])
with self.test_session(use_gpu=True):
- for stateless_op, stateful_op in CASES:
- for shape in (), (3,), (2, 5):
- stateful = stateful_op(shape, seed=seed[1])
- pure = stateless_op(shape, seed=preseed)
- self.assertAllEqual(stateful.eval(), pure.eval())
+ for stateless_op, stateful_op in cases:
+ stateful = stateful_op(seed=seed[1])
+ pure = stateless_op(seed=preseed)
+ self.assertAllEqual(stateful.eval(), pure.eval())
- def testDeterminism(self):
+ def _test_determinism(self, cases):
# Stateless values should be equal iff the seeds are equal (roughly)
+ cases = tuple(cases)
with self.test_session(use_gpu=True):
for seed_type in [dtypes.int32, dtypes.int64]:
seed_t = array_ops.placeholder(seed_type, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(shape, seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testShapeType(self):
- with self.test_session(use_gpu=True):
- for shape_dtype in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(dtypes.int64, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype),
- seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testMatchStatefulMultinomial(self):
- # Stateless ops should be the same as stateful ops on the first call
- # after seed scrambling.
- key = 0x3ec8f720, 0x02461e29
- num_samples = 4
- for logits_dtype in np.float16, np.float32, np.float64:
- for output_dtype in dtypes.int32, dtypes.int64:
- for seed in (7, 17), (11, 5), (2, 3):
- preseed = invert_philox(key,
- (seed[0], 0, seed[1], 0)).astype(np.uint64)
- preseed = preseed[::2] | preseed[1::2] << 32
- random_seed.set_random_seed(seed[0])
- with self.test_session(use_gpu=True):
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- logits_t = constant_op.constant(logits, dtype=logits_dtype)
- stateful = random_ops.multinomial(
- logits_t,
- num_samples,
- seed=seed[1],
- output_dtype=output_dtype)
- pure = stateless.stateless_multinomial(
- logits_t,
- num_samples,
- seed=preseed,
- output_dtype=output_dtype)
- self.assertAllEqual(stateful.eval(), pure.eval())
-
- def testDeterminismMultinomial(self):
- # Stateless values should be equal iff the seeds are equal (roughly)
- num_samples = 10
- with self.test_session(use_gpu=True):
- for seed_type in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(seed_type, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- pure = stateless.stateless_multinomial(
- logits, num_samples, seed=seed_t)
+ for stateless_op, _ in cases:
+ pure = stateless_op(seed=seed_t)
values = [
(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
]
@@ -142,6 +80,74 @@ class StatelessOpsTest(test.TestCase):
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
+ def _float_cases(self, shape_dtypes=(None,)):
+ float_cases = (
+ # Uniform distribution, with and without range
+ (stateless.stateless_random_uniform, random_ops.random_uniform, {}),
+ (stateless.stateless_random_uniform, random_ops.random_uniform,
+ dict(minval=2.2, maxval=7.1)),
+ # Normal distribution, with and without mean+stddev
+ (stateless.stateless_random_normal, random_ops.random_normal, {}),
+ (stateless.stateless_random_normal, random_ops.random_normal,
+ dict(mean=2, stddev=3)),
+ # Truncated normal distribution, with and without mean+stddev
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal,
+ dict(mean=3, stddev=4)),
+ )
+ for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for stateless_op, stateful_op, kwds in float_cases:
+ kwds = dict(shape=shape, dtype=dtype, **kwds)
+ yield (functools.partial(stateless_op, **kwds),
+ functools.partial(stateful_op, **kwds))
+
+ def _int_cases(self, shape_dtypes=(None,)):
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for dtype in dtypes.int32, dtypes.int64:
+ kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape)
+ yield (functools.partial(stateless.stateless_random_uniform, **kwds),
+ functools.partial(random_ops.random_uniform, **kwds))
+
+ def _multinomial_cases(self):
+ num_samples = 10
+ for logits_dtype in np.float16, np.float32, np.float64:
+ for output_dtype in dtypes.int32, dtypes.int64:
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ kwds = dict(
+ logits=constant_op.constant(logits, dtype=logits_dtype),
+ num_samples=num_samples,
+ output_dtype=output_dtype)
+ yield (functools.partial(stateless.stateless_multinomial, **kwds),
+ functools.partial(random_ops.multinomial, **kwds))
+
+ def testMatchFloat(self):
+ self._test_match(self._float_cases())
+
+ def testMatchInt(self):
+ self._test_match(self._int_cases())
+
+ def testMatchMultinomial(self):
+ self._test_match(self._multinomial_cases())
+
+ def testDeterminismFloat(self):
+ self._test_determinism(
+ self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismInt(self):
+ self._test_determinism(
+ self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismMultinomial(self):
+ self._test_determinism(self._multinomial_cases())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/stateless/python/stateless_ops.py b/tensorflow/contrib/stateless/python/stateless_ops.py
new file mode 100644
index 0000000000..1449825c83
--- /dev/null
+++ b/tensorflow/contrib/stateless/python/stateless_ops.py
@@ -0,0 +1,214 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless random ops which take seed as a tensor input."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_stateless_random_ops
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import math_ops
+
+ops.NotDifferentiable("StatelessMultinomial")
+ops.NotDifferentiable("StatelessRandomNormal")
+ops.NotDifferentiable("StatelessRandomUniform")
+ops.NotDifferentiable("StatelessRandomUniformInt")
+ops.NotDifferentiable("StatelessTruncatedNormal")
+
+
+def stateless_random_uniform(shape,
+ seed,
+ minval=0,
+ maxval=None,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a uniform distribution.
+
+ This is a stateless version of `tf.random_uniform`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a uniform distribution in the range
+ `[minval, maxval)`. The lower bound `minval` is included in the range, while
+ the upper bound `maxval` is excluded.
+
+ For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
+ be specified explicitly.
+
+ In the integer case, the random integers are slightly biased unless
+ `maxval - minval` is an exact power of two. The bias is small for values of
+ `maxval - minval` significantly smaller than the range of the output (either
+ `2**32` or `2**64`).
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
+ range of random values to generate. Defaults to 0.
+ maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
+ range of random values to generate. Defaults to 1 if `dtype` is floating
+ point.
+ dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or
+ `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random uniform values.
+
+ Raises:
+ ValueError: If `dtype` is integral and `maxval` is not specified.
+ """
+ dtype = dtypes.as_dtype(dtype)
+ if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
+ dtypes.float64, dtypes.int32, dtypes.int64):
+ raise ValueError("Invalid dtype %r" % dtype)
+ if maxval is None:
+ if dtype.is_integer:
+ raise ValueError("Must specify maxval for integer dtype %r" % dtype)
+ maxval = 1
+ with ops.name_scope(name, "stateless_random_uniform",
+ [shape, seed, minval, maxval]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+ maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
+ if dtype.is_integer:
+ return gen_stateless_random_ops.stateless_random_uniform_int(
+ shape, seed=seed, minval=minval, maxval=maxval, name=name)
+ else:
+ rnd = gen_stateless_random_ops.stateless_random_uniform(
+ shape, seed=seed, dtype=dtype)
+ return math_ops.add(rnd * (maxval - minval), minval, name=name)
+
+
+def stateless_random_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a normal distribution.
+
+ This is a stateless version of `tf.random_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
+ distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random normal values.
+ """
+ with ops.name_scope(name, "stateless_random_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_truncated_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values, truncated normally distributed.
+
+ This is a stateless version of `tf.truncated_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a normal distribution with specified mean and
+ standard deviation, except that values whose magnitude is more than 2 standard
+ deviations from the mean are dropped and re-picked.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
+ truncated normal distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution, before truncation.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random truncated normal values.
+ """
+ with ops.name_scope(name, "stateless_truncated_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_truncated_normal(
+ shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_multinomial(logits,
+ num_samples,
+ seed,
+ output_dtype=dtypes.int64,
+ name=None):
+ """Draws deterministic pseudorandom samples from a multinomial distribution.
+
+ This is a stateless version of `tf.multinomial`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Example:
+
+ ```python
+ # samples has shape [1, 5], where each value is either 0 or 1 with equal
+ # probability.
+ samples = tf.contrib.stateless.stateless_multinomial(
+ tf.log([[10., 10.]]), 5, seed=[7, 17])
+ ```
+
+ Args:
+ logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice
+ `[i, :]` represents the unnormalized log-probabilities for all classes.
+ num_samples: 0-D. Number of independent samples to draw for each row slice.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ name: Optional name for the operation.
+ output_dtype: integer type to use for the output. Defaults to int64.
+
+ Returns:
+ The drawn samples of shape `[batch_size, num_samples]`.
+ """
+ with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
+ logits = ops.convert_to_tensor(logits, name="logits")
+ return gen_stateless_random_ops.stateless_multinomial(
+ logits, num_samples, seed, output_dtype=output_dtype)
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 10ed1c2891..8c36d5a297 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -302,6 +302,7 @@ tf_py_test(
"//tensorflow/python:client_testlib",
":datasets",
],
+ flaky = 1, # TODO(b/117363808): fails 1/1000 OSS runs
grpc_enabled = True,
)
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index f88dc51636..1e66801efd 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -168,6 +168,12 @@ message RunEnvironmentResult {
optional HostIndependentJobInfoResult host_independent_job_info = 5;
// Host-dependent job information.
repeated HostDependentJobInfoResult host_dependent_job_info = 6;
+ // The number of replicas, corresponds to input parallelism.
+ // If there is no model parallelism, replica_count = tpu_core_count
+ optional int32 replica_count = 7;
+ // The number of cores used for a single replica, e.g. model parallelism.
+ // If there is no model parallelism, then num_cores_per_replica = 1
+ optional int32 num_cores_per_replica = 8;
}
// The types of host operations that are tracked.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index 8529b48c15..c2e3be03db 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -62,9 +62,9 @@ message FtrlParameters {
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
// order to get correct results; a warning will be printed otherwise (which may
-// change to an error in the future). If use_max_with_epsilon is set, the Adam
+// change to an error in the future). If use_sum_inside_sqrt is set, the Adam
// variable update formula will be changed from m / (sqrt(v) + epsilon) to
-// m / max(sqrt(v), abs(epsilon)); this option improves the performance of TPU
+// m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
message AdamParameters {
float beta1 = 3;
@@ -73,7 +73,7 @@ message AdamParameters {
float initial_m = 6;
float initial_v = 7;
bool use_non_lazy_adam = 8;
- bool use_max_with_epsilon = 9;
+ bool use_sum_inside_sqrt = 10;
}
// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
new file mode 100644
index 0000000000..280148e032
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "LeakyRelu"
+ visibility: HIDDEN
+ summary: "Computes rectified linear: `max(features, features * alpha)`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
new file mode 100644
index 0000000000..e427526602
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "LeakyReluGrad"
+ visibility: HIDDEN
+ in_arg {
+ name: "gradients"
+ description: <<END
+The backpropagated gradients to the corresponding LeakyRelu operation.
+END
+ }
+ in_arg {
+ name: "features"
+ description: <<END
+The features passed as input to the corresponding LeakyRelu operation,
+OR the outputs of that operation (both work equivalently).
+END
+ }
+ out_arg {
+ name: "backprops"
+ description: <<END
+`gradients * (features > 0) + alpha * gradients * (featurs <= 0)`.
+END
+ }
+ summary: "Computes rectified linear gradients for a LeakyRelu operation."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
new file mode 100644
index 0000000000..b6a6dbdf54
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
@@ -0,0 +1,46 @@
+op {
+ graph_op_name: "StatelessRandomUniformInt"
+ visibility: HIDDEN
+ in_arg {
+ name: "shape"
+ description: <<END
+The shape of the output tensor.
+END
+ }
+ in_arg {
+ name: "seed"
+ description: <<END
+2 seeds (shape [2]).
+END
+ }
+ in_arg {
+ name: "minval"
+ description: <<END
+Minimum value (inclusive, scalar).
+END
+ }
+ in_arg {
+ name: "maxval"
+ description: <<END
+Maximum value (exclusive, scalar).
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Random values with specified shape.
+END
+ }
+ attr {
+ name: "dtype"
+ description: <<END
+The type of the output.
+END
+ }
+ summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
+ description: <<END
+The generated values follow a uniform distribution in the range `[minval, maxval)`.
+
+The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`.
+END
+}
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index db137f1a19..e81e61b633 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -466,23 +466,23 @@ Graph* GetConstantGraph(
bool ReplaceTensorWithConstant(
Graph* graph, Device* partition_device, NodeAndOutput tensor,
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
- int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
+ int64 max_constant_size_in_bytes,
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
// 1) Do not replace another constant.
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 3) If the size of the constant in bytes is too large (>
+ // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
+ // constraint, do not replace it.
+ // 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
- // 4) If the constant op created does not have a kernel implementation
+ // 5) If the constant op created does not have a kernel implementation
// for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) If the constant op for the device has different output memory type
- // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if (memory_type == HOST_MEMORY && !is_int32) {
+ if ((memory_type == HOST_MEMORY && !is_int32) ||
+ (memory_type == DEVICE_MEMORY && is_int32)) {
return false;
}
}
@@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
- if (!disable_memory_output_type_check) {
- if (partition_device && device_type != DEVICE_CPU) {
- MemoryType original_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
- &original_output_memory_type)
- .ok()) {
- return false;
- }
- MemoryType const_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
- &const_output_memory_type)
- .ok()) {
- return false;
- }
- if (original_output_memory_type != const_output_memory_type) {
- return false;
- }
- }
- }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
@@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant(
graph, partition_device, tensors_to_replace[c], outputs[c],
- control_deps, opts.max_constant_size_in_bytes,
- opts.disable_memory_output_type_check, generate_new_name)) {
+ control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
++num_nodes_replaced;
}
}
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index 4c71b7bd27..a9a84f761b 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -45,10 +45,6 @@ struct ConstantFoldingOptions {
// optimization.
int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
- // If disable_memory_output_type_check is true, we will disable output memory
- // type check for constant node replacement.
- bool disable_memory_output_type_check = false;
-
// A generator for the name suffix of constant folded nodes. A
// default id generator that monotonically increases is used if nullptr is
// passed.
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 2c48084cab..40ec1502da 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -1240,6 +1241,7 @@ class ExecutorState {
StepStatsCollectorInterface* const stats_collector_;
const tracing::TraceCollector* const trace_collector_;
const tracing::EventCollector* const event_collector_;
+ Context context_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
@@ -1367,6 +1369,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
trace_collector_(tracing::GetTraceCollector()),
event_collector_(
tracing::GetEventCollector(tracing::EventCategory::kCompute)),
+ context_(ContextKind::kThread),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
@@ -1586,6 +1589,7 @@ bool MightTrace(const NodeItem& item,
}
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
+ WithContext wc(context_);
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
TaggedNodeReadyQueue inline_ready;
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 91194bc86f..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -39,8 +39,7 @@ void GraphOptimizer::Optimize(
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn,
- const std::function<bool(const Node*)>& cf_consider_fn,
- bool cf_disable_memory_output_type_check) {
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -65,8 +64,6 @@ void GraphOptimizer::Optimize(
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
cf_opts.consider = cf_consider_fn;
- cf_opts.disable_memory_output_type_check =
- cf_disable_memory_output_type_check;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 8954e9612d..789cc56942 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -47,16 +47,13 @@ class GraphOptimizer {
// returns true will be considered for CSE.
// If cf_consider_fn is not null then only nodes for which cf_consider_fn
// returns true will be considered for CF.
- // If cf_disable_memory_output_type_check is true, CF will discard output
- // memory type check for constant node replacement.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
- const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
- bool cf_disable_memory_output_type_check = false);
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index a02084f223..9306386117 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
TF_CHECK_OK(if_op_->input_node(0, &pred_));
+ then_call_builder_.Device(if_op_->requested_device());
+ else_call_builder_.Device(if_op_->requested_device());
}
Status CondBuilder::CreatePivotNodes() {
@@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() {
NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
.Input(NodeOut(pred_, 0))
.Input(NodeOut(pred_, 0))
+ .Device(if_op_->requested_device())
.Finalize(graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
.Input(switch_pred, kElseBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_f_));
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
.Input(switch_pred, kThenBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_t_));
return Status::OK();
}
@@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
.Input(src, src_output)
.Input(pred_, 0)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
@@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() {
TF_RETURN_IF_ERROR(
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
.Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Device(if_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
@@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
Status CondBuilder::BuildLoweredIfOutput() {
// Build the identity node output.
NodeBuilder ib(name_, "IdentityN");
- ib.Input(outputs_);
+ ib.Input(outputs_).Device(if_op_->requested_device());
return ib.Finalize(graph_, &lowered_if_output_);
}
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index fa4d1eda62..9488a44778 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
"output_port '", output_port, "' is out of range, ", "node '",
node->name(), "' has ", node->num_outputs(), " outputs");
}
+ // Note: it's possible, if the node's been updated, that the shape inference
+ // context doesn't have the right number of outputs.
+ if (node->num_outputs() > c->num_outputs()) {
+ TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
+ }
// Check compatibility, and merge the shapes.
ShapeHandle existing_shape = c->output(output_port);
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
output_handle_shapes_and_types_.resize(num_outputs);
}
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+ if (new_output_size < outputs_.size()) {
+ return errors::InvalidArgument("Trying to reduce number of outputs of op.");
+ }
+ outputs_.resize(new_output_size, nullptr);
+ output_handle_shapes_and_types_.resize(new_output_size);
+ return Status::OK();
+}
+
void InferenceContext::PostInputInit(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
return input_tensors_as_shapes_;
}
- ShapeHandle output(int64 idx) const { return outputs_[idx]; }
- void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+ ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+ void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
int num_outputs() const { return outputs_.size(); }
- ShapeHandle output(int idx) const { return outputs_[idx]; }
+ ShapeHandle output(int idx) const { return outputs_.at(idx); }
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
@@ -645,6 +645,9 @@ class InferenceContext {
return merged_dims_;
}
+ // Adds new outputs; useful when mutating the graph.
+ Status ExpandOutputs(int new_output_size);
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7a4a0096fa..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -142,6 +142,19 @@ void Node::Clear() {
assigned_device_name_index_ = 0;
}
+void Node::UpdateProperties() {
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ Status status =
+ InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed at updating node: " << status;
+ return;
+ }
+ props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
+ inputs, outputs);
+}
+
const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 2944951f82..228b1331d9 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -171,6 +171,7 @@ class Node {
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
+ UpdateProperties();
}
void ClearAttr(const string& name);
@@ -211,6 +212,10 @@ class Node {
// e.g. in AddAttr.
void MaybeCopyOnWrite();
+ // Called after an attr has changed. Decides whether we need to update some
+ // property of the node (stored in props_).
+ void UpdateProperties();
+
AttrValue* AddAttrHelper(const string& name);
// A set of mutually exclusive classes for different kinds of nodes,
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index d92874909f..68a20fcc5f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) {
strings::StrCat("Attempt to add nullptr Node to node with type ",
def_builder_.op_def().name()));
} else {
- errors_.emplace_back(
- strings::StrCat("Attempt to add output ", i, " of ", node->name(),
- " not in range [0, ", node->num_outputs(),
- ") to node with type ", def_builder_.op_def().name()));
+ errors_.emplace_back(strings::StrCat(
+ "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
+ node->num_outputs(), ") to node with type ",
+ def_builder_.op_def().name(), ". Node: ", node->DebugString()));
}
}
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 1b5a215987..cbf5c8e038 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -102,15 +102,19 @@ bool IsConjugateTranspose(const NodeDef& node) {
}
bool IsControlFlow(const NodeDef& node) {
- // clang-format off
- return node.op() == "ControlTrigger" ||
- node.op() == "Enter" ||
- node.op() == "Exit" ||
- node.op() == "LoopCond" ||
- node.op() == "Merge" ||
- node.op() == "NextIteration" ||
- node.op() == "Switch";
- // clang-format on
+ // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
+ // string comparison.
+ static const gtl::FlatSet<string>* const kControFlowOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "ControlTrigger",
+ "Enter",
+ "Exit",
+ "LoopCond",
+ "Merge",
+ "NextIteration",
+ "Switch",
+ }));
+ return kControFlowOps->count(node.op()) > 0;
}
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 37aa24b947..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -13,9 +13,19 @@ VECTORIZER_DEPS = [
] + tf_protos_all()
cc_library(
+ name = "wrapped_tensor",
+ hdrs = ["wrapped_tensor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ ":wrapped_tensor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index 3af6bab409..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -19,13 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
@@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer {
auto new_cast_node = outer_scope->AddNode(node.def(), &s);
TF_RETURN_IF_ERROR(s);
- // Add input and output mappings
- input_ports->push_back({new_cast_node, 0});
- output_ports->push_back({new_cast_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 74ce520ce1..f1ba741821 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -19,15 +19,15 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class UnpackVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
- if (node.num_inputs() != 1) {
+ if (node.num_inputs() != 1 || inputs.size() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
@@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer {
int new_axis = node.def().attr().at("axis").i() + 1;
new_unpack_node->AddAttr("axis", new_axis);
- // Add the input mappings
- input_ports->push_back({new_unpack_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index,
+ new_unpack_node, 0);
// Add the output mappings
int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- output_ports->push_back({new_unpack_node, i});
+ outputs->push_back({new_unpack_node, i, true});
}
return Status::OK();
@@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer {
REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index 56eb88c95e..8d4676aae0 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -18,15 +18,12 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
-
-// Describes a tensor with its operation Node and output position
-typedef std::pair<Node*, int> Port;
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
@@ -36,17 +33,17 @@ class Vectorizer {
// Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs. The new Node(s) collectively have the
+ // on elements of `inputs`. The new Node(s) collectively have the
// same number of input and output ports as the node being converted.
- // Adds mappings for the new nodes' input and output ports to `inputs` and
- // `outputs` respectively, where the i'th Port in inputs/outputs
- // corresponds to the i'th input/output port of the node to be converted.
+ // Adds edges between the newly created nodes and nodes in `inputs`, and adds
+ // mappings to the new nodes' output ports to `outputs`, where the i'th
+ // value in `outputs` corresponds to the i'th output port of the node
+ // to be converted.
virtual Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) = 0;
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) = 0;
};
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
index a6551e36ac..e1cf77a7d5 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
VectorizerRegistry* VectorizerRegistry::Global() {
static VectorizerRegistry* registry = new VectorizerRegistry;
@@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type,
vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
op_type, std::move(vectorizer)));
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
index 16159d47ca..ad54c74933 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -23,7 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// A global VectorizerRegistry is used to hold all the vectorizers.
class VectorizerRegistry {
@@ -59,16 +58,12 @@ class VectorizerRegistration {
#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
-#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
- static ::tensorflow::grappler::vectorization_utils:: \
- vectorizer_registration::VectorizerRegistration \
- vectorizer_registration_##ctr( \
- op_type, \
- ::std::unique_ptr< \
- ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
- new vectorizer()))
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorizer_registration:: \
+ VectorizerRegistration vectorizer_registration_##ctr( \
+ op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \
+ new vectorizer()))
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 663ceba027..054aeb9a8f 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -20,13 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* inputs,
- std::vector<Port>* outputs) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
return Status::OK();
}
};
@@ -43,10 +42,10 @@ TEST(TestVectorizer, TestTestVectorizer) {
NodeDef node_def;
Status s;
Node* node = g.AddNode(node_def, &s);
- std::vector<Port> inputs, outputs;
- EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
+ std::vector<WrappedTensor> inputs, outputs;
+ EXPECT_TRUE(
+ vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok());
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
new file mode 100644
index 0000000000..4439b4ab4e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Represents a tensor that has been vectorized.
+struct WrappedTensor {
+ Node* const node;
+ const int output_index;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
+ bool stacked;
+
+ WrappedTensor(Node* node, int output_index, bool stacked)
+ : node(node), output_index(output_index), stacked(stacked) {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 344c420902..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -45,22 +45,6 @@ namespace {
// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;
-// Equivalent to python Pfor's WrappedTensor struct
-struct WrappedTensor {
- TensorDesc tensor;
-
- // Whether the tensor is stacked, i.e. represents the results of applying
- // the operation on all slices of the input, where each row i of the
- // tensor corresponds to the op's output on slice i of the input. False
- // if the tensor is not stacked, i.e. represents the result of the op on
- // a single slice of the input, where the result does not vary between
- // slices.
- bool stacked;
-
- WrappedTensor(TensorDesc&& tensor, bool stacked)
- : tensor(std::move(tensor)), stacked(stacked) {}
-};
-
const char* const kRetValOp = "_Retval";
void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
@@ -239,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
return errors::Unimplemented("No vectorizer registered for op: ",
op_node->type_string());
}
- std::vector<Port> input_ports, output_ports;
- input_ports.reserve(op_node->num_inputs());
- output_ports.reserve(op_node->num_outputs());
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- &input_ports, &output_ports));
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
- if (op_node->num_outputs() != output_ports.size() ||
- op_node->num_inputs() != input_ports.size() ||
- input_edges.size() != input_ports.size()) {
- return errors::Internal("Vectorizer inputs/outputs don't match.");
- }
-
- // Promote the inputs of the op to MapDefun outputs and connect the edges
- // accordingly.
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
for (size_t i = 0; i < op_node->num_inputs(); ++i) {
auto edge = input_edges[i];
- TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
- {edge->src(), edge->src_output()}));
- outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
- input_ports[i].first, input_ports[i].second);
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
+
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
// Add output mappings.
for (size_t i = 0; i < op_node->num_outputs(); ++i) {
- conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}});
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
return Status::OK();
@@ -281,25 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) {
TensorDesc output({ret_edge->src(), ret_edge->src_output()});
TensorDesc converted_output;
- if (auto found = gtl::FindOrNull(conversion_map_, output)) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (found->stacked) {
- converted_output = found->tensor;
- } else {
- // Some outputs may be unstacked if they don't derive from arg nodes
- // (for example, if a function returns a constant). For these, we
- // have to add extra nodes to tile it in the 0th dimension.
- TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
- }
- } else {
- // Note: All unstacked nodes are converted ahead of time in `Initialize`,
- // and here we assume that all op vectorizers create only stacked outputs.
- // This may not hold in the future, as more vectorizers are added that
- // may actually create unstacked outputs. For example, see the `Shape`
- // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
- converted_output = conversion_map_.at(output).tensor;
+ found = &conversion_map_.at(output);
+ }
+
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile it in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
@@ -455,7 +450,7 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* ones_shape;
TF_RETURN_IF_ERROR(node_builder("Shape")
- .Input(unstacked->tensor.first) // input
+ .Input(unstacked->node) // input
.Finalize(g, &ones_shape));
Node* ones;
@@ -473,8 +468,8 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* expand_dims;
TF_RETURN_IF_ERROR(node_builder("ExpandDims")
- .Input(unstacked->tensor.first) // input
- .Input(const_0) // dim
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
.Finalize(g, &expand_dims));
TF_RETURN_IF_ERROR(node_builder("Tile")
@@ -491,11 +486,11 @@ Status Vectorization::AddArgNodeMappings() {
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
// Control inputs
conversion_map_.insert({{arg_node, Graph::kControlSlot},
- {{input_node, Graph::kControlSlot}, true}});
+ {input_node, Graph::kControlSlot, true}});
}
return Status::OK();
}
@@ -541,7 +536,7 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
if (auto found = gtl::FindOrNull(conversion_map_,
{edge->src(), edge->src_output()})) {
- outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node,
+ outer_scope_->AddEdge(found->node, found->output_index, node,
edge->dst_input());
} else {
status->Update(errors::Internal(
@@ -552,11 +547,10 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
// Add output mappings
for (int i = 0; i < tensor.first->num_outputs(); ++i) {
- conversion_map_.insert(
- {{tensor.first, i}, WrappedTensor({node, i}, false)});
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
}
conversion_map_.insert({{tensor.first, Graph::kControlSlot},
- WrappedTensor({node, Graph::kControlSlot}, false)});
+ WrappedTensor(node, Graph::kControlSlot, false)});
return true;
}
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h
index 765dd13263..bd6bf9f860 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h
@@ -16,8 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+#include <atomic>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace grappler {
@@ -29,6 +32,7 @@ struct GrapplerItem;
// optimization of a GrapplerItem for running on a cluster.
class GraphOptimizer {
public:
+ GraphOptimizer() : is_cancelled_(false) {}
virtual ~GraphOptimizer() {}
virtual string name() const = 0;
@@ -45,8 +49,25 @@ class GraphOptimizer {
// call to Optimize) performed. Lower "result" scores are better.
virtual void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) = 0;
+
+ // Best effort cancellation. Sets is_cancelled to true and requests that the
+ // optimizer returns as soon as possible from active calls to Optimize() or
+ // FeedBack().
+ void Cancel() { is_cancelled_ = true; }
+
+ bool is_cancelled() const { return is_cancelled_; }
+
+ private:
+ std::atomic<bool> is_cancelled_;
};
+#define GRAPPLER_RETURN_IF_CANCELLED() \
+ do { \
+ if (is_cancelled()) { \
+ return errors::DeadlineExceeded(this->name(), " was cancelled."); \
+ } \
+ } while (0)
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index c3d70a1fdf..7488cedec5 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+
+#include <memory>
+
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -37,7 +40,11 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -107,13 +114,29 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
- MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
+ MK_OPT("pin_to_host",
+ new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
#undef MK_OPT
+MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
+ : cpu_device_(cpu_device), cfg_(cfg) {
+ // TODO(rmlarsen): Increase kNumThreads to, say, port::NumSchedulableCPUs()
+ // if we want to the threadpool for parallelizing Grappler
+ const int kNumThreads = 1;
+ thread_pool_ = absl::make_unique<thread::ThreadPool>(
+ Env::Default(), "MetaOptimizerThreadPool", kNumThreads);
+}
+
+MetaOptimizer::~MetaOptimizer() {
+ // The ThreadPool destructor waits for threads to finish, so we don't
+ // pull the rug out from under them.
+ thread_pool_.reset();
+}
+
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
if (cfg_.disable_meta_optimizer()) {
@@ -139,7 +162,7 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
- if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+ if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
@@ -309,6 +332,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
VLOG(4) << "Starting optimization iteration " << iteration;
for (const auto& optimizer : optimizers) {
+ GRAPPLER_RETURN_IF_CANCELLED();
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
// Some must run only on the last iteration.
@@ -367,6 +391,7 @@ Status MetaOptimizer::RunOptimizer(
// resets optimized_graph to an empty graph.
optimized_graph->Swap(&optimized_item->graph);
*optimized_graph = GraphDef();
+ // TODO(rmlarsen): Add timeout for individual optimizers.
Status status =
optimizer->Optimize(cluster, *optimized_item, optimized_graph);
uint64 end_us = Env::Default()->NowMicros();
@@ -388,14 +413,15 @@ Status MetaOptimizer::RunOptimizer(
return status;
}
-Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
+Status MetaOptimizer::OptimizeMainGraphAndFunctionLibrary(
+ Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) {
VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
VLOG(1) << "Optimized main graph.";
+ GRAPPLER_RETURN_IF_CANCELLED();
// Skip optimizing functions if this is a TPU graph. Currently, Grappler
// passes do not handle TPU functions correctly in a variety of ways (Note
@@ -431,6 +457,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimize_function_library = false;
for (const FunctionDef& func : optimized_graph->library().function()) {
+ GRAPPLER_RETURN_IF_CANCELLED();
+
const string& func_name = func.signature().name();
// Skip already optimized functions.
@@ -505,6 +533,43 @@ void MetaOptimizer::PrintResult() {
}
}
+Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ const int64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
+ const int64 timeout_usec = (cfg_.meta_optimizer_timeout_ms() == 0
+ ? kFiveMinutesInUsec
+ : cfg_.meta_optimizer_timeout_ms() * 1000);
+ if (timeout_usec < 0) {
+ return OptimizeMainGraphAndFunctionLibrary(cluster, item, optimized_graph);
+ }
+
+ GraphDef optimized_with_timeout;
+ Status status;
+ Notification done;
+ thread_pool_->Schedule(
+ [this, cluster, &done, &optimized_with_timeout, &item, &status]() {
+ status = this->OptimizeMainGraphAndFunctionLibrary(
+ cluster, item, &optimized_with_timeout);
+ done.Notify();
+ });
+
+ const bool notified = WaitForNotificationWithTimeout(&done, timeout_usec);
+ if (notified && status.ok()) {
+ optimized_graph->Swap(&optimized_with_timeout);
+ } else {
+ *optimized_graph = item.graph;
+ if (!notified) {
+ this->Cancel();
+ done.WaitForNotification();
+ status = errors::DeadlineExceeded(
+ "Grappler MetaOptimizer timed out after ",
+ static_cast<float>(timeout_usec) / (1000 * 1000), " seconds");
+ LOG(WARNING) << status.error_message();
+ }
+ }
+ return status;
+}
+
void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& pruned_graph, double result) {
// Nothing to do for MetaOptimizer.
@@ -527,7 +592,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
- cfg.pin_to_host_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() != RewriterConfig::OFF ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 99a0a33ffa..35d6a4559b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -28,9 +29,8 @@ namespace grappler {
// Run the other grappler optimizers based on the specified rewriter config.
class MetaOptimizer : public GraphOptimizer {
public:
- MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
- : cpu_device_(cpu_device), cfg_(cfg) {}
- ~MetaOptimizer() override = default;
+ MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg);
+ ~MetaOptimizer();
string name() const override { return "meta_optimizer"; };
@@ -65,9 +65,18 @@ class MetaOptimizer : public GraphOptimizer {
Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph);
+ // Run optimization passes over the main graph and for functions in the
+ // function library.
+ Status OptimizeMainGraphAndFunctionLibrary(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph);
+
DeviceBase* const cpu_device_; // may be NULL
RewriterConfig cfg_;
+ // Thread pool used for launching optimizers asynchronously.
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+
struct OptimizerResult {
string optimizer_name;
string result;
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 3f3f43382f..7f1dd91f09 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -461,6 +461,68 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
}
+class SleepingOptimizer : public CustomGraphOptimizer {
+ public:
+ SleepingOptimizer() {}
+ string name() const override { return "test_optimizer"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ optimized_graph->add_node();
+ sleep(1);
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+};
+
+REGISTER_GRAPH_OPTIMIZER(SleepingOptimizer);
+
+TEST_F(MetaOptimizerTest, OptimizerTimesOut) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("SleepingOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+ rewriter_config.set_meta_optimizer_timeout_ms(1500);
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ EXPECT_EQ(status.error_message(),
+ "Grappler MetaOptimizer timed out after 1.5 seconds");
+ // Make sure the graph was reverted to the original regardless of when the
+ // optimizer timed out.
+ CompareGraphs(item.graph, output);
+}
+
+TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("SleepingOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+ rewriter_config.set_meta_optimizer_timeout_ms(1500);
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_EQ(item.graph.node_size() + 1, output.node_size());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 8ed4271fa4..29a3b2b74c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,16 +25,29 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
+
namespace internal {
+namespace {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
+struct OpDevicePortHasher {
+ std::size_t operator()(const std::tuple<string, string, int>& x) const {
+ uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
+
+ return Hash64Combine(code, hash<int>()(std::get<2>(x)));
+ }
+};
+using OpDevicePortOnHostMap =
+ gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
+
// All the nodes that should be blacklisted and not swapped.
bool IsBlacklisted(const NodeDef& node) {
return
@@ -82,10 +95,10 @@ Status TryFindKernelDef(const std::vector<DeviceType>& devices,
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
-Status IsNodeOutputPortHostFriendly(const GraphView& graph,
- GraphProperties* properties,
- const NodeDef& node, int port_id,
- bool* is_candidate) {
+Status IsNodeOutputPortHostFriendly(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
// Make sure we are not a blacklisted op.
@@ -117,7 +130,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
for (const auto& fanin : graph.GetFanins(node, false)) {
bool fanin_candidate = false;
TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
if (!fanin_candidate) {
return Status::OK();
}
@@ -132,11 +146,22 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
return Status::OK();
}
+ // Check `op_device_outport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ const std::tuple<string, string, int> cache_key(node.op(), node.device(),
+ port_id);
+ auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_outport_pinned_to_host_cache->end()) {
+ *is_candidate = it->second;
+ return Status::OK();
+ }
+
// Check if op's output port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -146,6 +171,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
LOG(WARNING) << "Invalid port: " << port_id << "!\n"
<< node.DebugString() << "\n"
<< op->DebugString();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -155,6 +181,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
&kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -166,22 +193,35 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
}
}
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
+
return Status::OK();
}
// Checks if a node's input port is Host friendly.
// Roughly this means checking if the input port is on Host memory.
-bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
+bool IsNodeInputPortHostFriendly(
+ const NodeDef& node, int port_id,
+ OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
// If node is on Host, assume its inputs are Host friendly.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
return true;
}
+ // Check `op_device_inport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
+ auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_inport_pinned_to_host_cache->end()) {
+ return it->second;
+ }
+
// Check if op's input port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
@@ -192,16 +232,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
{node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
// Check if the input_arg is pinned to Host.
for (const string& host_memory_arg : kernel->host_memory_arg()) {
if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
return true;
}
}
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+
return false;
}
@@ -211,18 +255,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
// ints, and pinned to Host).
-Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
- const NodeDef& node, bool* is_candidate) {
+Status IsNodeHostCandidate(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
- // Check if node already on CPU.
- if (str_util::StrContains(node.device(), DEVICE_CPU)) {
- *is_candidate = true;
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
return Status::OK();
}
- // Skip these node types.
- if (IsBlacklisted(node)) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
return Status::OK();
}
@@ -232,17 +278,6 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
return Status::OK();
}
- // Check all inputs are Host friendly.
- for (const GraphView::OutputPort& fanin :
- graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
- bool fanin_candidate = false;
- TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
- if (!fanin_candidate) {
- return Status::OK();
- }
- }
-
// Check all outputs are Host friendly.
if (!properties->has_properties()) {
// This is an expensive call, call it lazily.
@@ -255,16 +290,42 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
}
}
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
+ }
+ }
+
*is_candidate = true;
return Status::OK();
}
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device) {
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+} // end namespace
+
+// Tries to swap `device` to a Host device from `devices`. Returns true iff
+// there was a swap.
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device) {
// Force this node onto the CPU.
- if (device.empty() && has_device_cpu) {
- return "/device:CPU:0";
- } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ if (device->empty() && has_device_cpu) {
+ *device = "/device:CPU:0";
+ return true;
+ } else if (str_util::StrContains(*device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
@@ -272,27 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
- strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ strings::StrCat(device->substr(0, device->rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
- return device_host;
+ *device = device_host;
+ return true;
}
}
}
- // We couldn't find an appropriate Host device, return original device.
- return device;
-}
-
-bool IsTPUGraphDef(const GraphDef& def) {
- for (const auto& node : def.node()) {
- if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
- node.op() == "TPUPartitionedCall") {
- return true;
- }
- }
+ // We couldn't find an appropriate Host device, return false.
return false;
}
+
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -324,20 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// All the Const nodes, and their original devices in topological order.
std::vector<std::pair<NodeDef*, string>> const_nodes;
+ // Cache to map {op, device, port} -> bool on whether it is pinned to host.
+ internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
+ internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
+
for (auto& node : *optimized_graph->mutable_node()) {
bool is_candidate = false;
- TF_RETURN_IF_ERROR(
- internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
+ TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
+ graph, &properties, node, &op_device_outport_pinned_to_host_cache,
+ &is_candidate));
if (!is_candidate) {
continue;
}
- if (IsConstant(node)) {
- const_nodes.emplace_back(&node, node.device());
+ const string original_device = node.device();
+ const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
+ node.mutable_device());
+ // Keep track of all Const nodes that we swapped.
+ if (swapped && IsConstant(node)) {
+ const_nodes.emplace_back(&node, original_device);
}
- // Try and swap the device to Host.
- node.set_device(
- internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
// Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -349,8 +408,9 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
// The consumer is not Host friendly, swap it back to the original device.
- if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
- fanout.port_id)) {
+ if (!internal::IsNodeInputPortHostFriendly(
+ *fanout.node, fanout.port_id,
+ &op_device_inport_pinned_to_host_cache)) {
node->set_device(device);
break;
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index d557a03463..bed4a9ef95 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
namespace grappler {
namespace internal {
// Try and find an appropriate Host device in `devices` given `device`.
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device);
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 7c64529441..9bb030b220 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,30 +28,60 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
gtl::FlatSet<string> devices = {};
- EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
-
- devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
- "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
- "/device:CPU:0");
-
- devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_CPU:0");
-
- devices = {"/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_GPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_GPU:*");
+
+ string device = "ABC";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "ABC");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 451f8c1a6c..37c1c54786 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -45,6 +45,16 @@ cc_library(
],
)
+tf_cc_test(
+ name = "dataset_utils_test",
+ srcs = ["dataset_utils_test.cc"],
+ deps = [
+ ":dataset_utils",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_library(
name = "captured_function",
srcs = ["captured_function.cc"],
@@ -205,6 +215,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -232,6 +243,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -245,6 +257,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -285,6 +298,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
":parallel_map_iterator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e10833f525..a40f7f2146 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -15,10 +15,57 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace data {
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices) {
+ FunctionLibraryRuntime::Handle fn_handle;
+ TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
+ func.name(), AttrSlice(&func.attr()), &fn_handle));
+ auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
+ Status s = ctx->function_library()->ReleaseHandle(fn_handle);
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to release handle: " << s.error_message();
+ }
+ });
+
+ const FunctionBody* fn_body =
+ ctx->function_library()->GetFunctionBody(fn_handle);
+ indices->resize(fn_body->ret_nodes.size());
+ for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
+ Node* ret_node = fn_body->ret_nodes[i];
+ Node* ret_input_node;
+ TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
+ if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
+ } else {
+ indices->clear();
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
+ std::map<int, int> last_use;
+ for (size_t i = 0; i < indices.size(); ++i) {
+ last_use[indices[i]] = i;
+ }
+ std::vector<bool> can_move;
+ can_move.resize(indices.size());
+ for (size_t i = 0; i < indices.size(); ++i) {
+ can_move[i] = last_use[indices[i]] == i;
+ }
+ return can_move;
+}
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6ec1350cd4..d777062293 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -22,6 +22,26 @@ limitations under the License.
namespace tensorflow {
namespace data {
+// This method is used to determine whether we can short-circuit the evaluation
+// of the user-defined function `func`. Short-circuting is possible if every
+// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) =
+// (y,x)`, or `f(x) = (x,x)`).
+//
+// If short-circuiting is possible, the method stores the mapping from output
+// indices to input indices in `indices`. Otherwise, `indices` will be empty.
+//
+// Returns non-ok status if analysis of the function fails.
+//
+// TODO(jsimsa): Extend this to support constants as well.
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices);
+
+// Given a vector that maps output indices to input indices, return a vector
+// that identifies for which output indices can we move the input (assuming
+// output indices are processed left to right).
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices);
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc
new file mode 100644
index 0000000000..43295b8ebb
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_utils_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+TEST(DatasetUtils, ComputeMoveVector) {
+ struct TestCase {
+ std::vector<int> indices;
+ std::vector<bool> expected;
+ };
+
+ TestCase test_cases[] = {
+ TestCase{{}, {}},
+ TestCase{{1}, {true}},
+ TestCase{{1, 1}, {false, true}},
+ TestCase{{1, 2}, {true, true}},
+ TestCase{{1, 1, 2}, {false, true, true}},
+ TestCase{{1, 2, 2}, {true, false, true}},
+ };
+
+ for (auto& test_case : test_cases) {
+ EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices));
+ }
+}
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index 00884314a9..be7d182a1f 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -18,9 +18,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -31,67 +33,84 @@ namespace {
class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
+ using FilterIteratorPredicate =
+ std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>;
+
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- FunctionLibraryRuntime::Handle pred_handle;
- OP_REQUIRES_OK(ctx,
- ctx->function_library()->Instantiate(
- func_.name(), AttrSlice(&func_.attr()), &pred_handle));
- auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() {
- OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle));
- });
-
- const FunctionBody* pred_body =
- ctx->function_library()->GetFunctionBody(pred_handle);
- OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1,
- errors::InvalidArgument(
- "predicate function must have a single return value."));
- Node* ret_node = pred_body->ret_nodes[0];
- Node* ret_input_node;
- OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
-
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- if (ret_input_node->def().op() == "_Arg") {
- int32 index = -1;
- OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index));
- *output = new FilterTensorDataset(ctx, input, func_,
- std::move(captured_func), index);
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+ OP_REQUIRES(ctx, indices.size() <= 1,
+ errors::InvalidArgument(
+ "predicate function has more than one return value."));
+
+ FilterIteratorPredicate filter_pred;
+ if (indices.empty()) {
+ CapturedFunction* raw_captured_func = captured_func.get();
+ filter_pred = [raw_captured_func](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ std::vector<Tensor> result;
+ TF_RETURN_IF_ERROR(
+ raw_captured_func->RunWithBorrowedArgs(ctx, args, &result));
+
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = result[0].scalar<bool>()();
+ return Status::OK();
+ };
} else {
- *output = new FilterFunctionDataset(ctx, input, func_,
- std::move(captured_func));
+ filter_pred = [indices](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ const Tensor& predicate = args[indices[0]];
+ if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = predicate.scalar<bool>()();
+ return Status::OK();
+ };
}
+
+ *output = new Dataset(ctx, input, func_, std::move(captured_func),
+ std::move(filter_pred));
}
private:
- const int graph_def_version_;
-
- class FilterDatasetBase : public DatasetBase {
+ class Dataset : public DatasetBase {
public:
- FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func,
+ FilterIteratorPredicate filter_pred)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ filter_pred_(std::move(filter_pred)) {
input_->Ref();
}
- ~FilterDatasetBase() override { input_->Unref(); }
+ ~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Filter")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Filter")},
+ filter_pred_);
}
const DataTypeVector& output_dtypes() const override {
@@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- virtual Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const = 0;
-
private:
- class Iterator : public DatasetIterator<FilterDatasetBase> {
+ class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params),
+ explicit Iterator(const Params& params,
+ FilterIteratorPredicate filter_pred)
+ : DatasetIterator<Dataset>(params),
filtered_elements_(0),
- dropped_elements_(0) {
+ dropped_elements_(0),
+ filter_pred_(std::move(filter_pred)) {
std::vector<string> components =
str_util::Split(params.prefix, "::", str_util::SkipEmpty());
prefix_end_ = components.back();
@@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- TF_RETURN_IF_ERROR(
- dataset()->EvaluatePredicate(ctx, *out_tensors, &matched));
+ TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched));
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
@@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
int64 filtered_elements_ GUARDED_BY(mu_);
int64 dropped_elements_ GUARDED_BY(mu_);
+ const FilterIteratorPredicate filter_pred_;
string prefix_end_;
};
const DatasetBase* const input_;
const NameAttrList func_;
-
- protected:
const std::unique_ptr<CapturedFunction> captured_func_;
- };
-
- class FilterFunctionDataset : public FilterDatasetBase {
- public:
- using FilterDatasetBase::FilterDatasetBase;
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- std::vector<Tensor> result;
- TF_RETURN_IF_ERROR(
- captured_func_->RunWithBorrowedArgs(ctx, element, &result));
-
- if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = result[0].scalar<bool>()();
- return Status::OK();
- }
- };
-
- class FilterTensorDataset : public FilterDatasetBase {
- public:
- FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func,
- int32 index)
- : FilterDatasetBase(ctx, input, func, std::move(captured_func)),
- index_(index) {}
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- const Tensor& predicate = element[index_];
- if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = predicate.scalar<bool>()();
- return Status::OK();
- }
-
- private:
- const int32 index_;
+ const FilterIteratorPredicate filter_pred_;
};
private:
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 7a833668ac..8acd6cc724 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -16,10 +16,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
-#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
-#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
@@ -27,13 +25,11 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace data {
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index bf08970560..0fb721cd7c 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -41,6 +43,10 @@ namespace {
// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapAndBatchIteratorFunction =
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
+ std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;
+
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
@@ -91,31 +97,73 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
- drop_remainder, output_types_, output_shapes_, func_,
- std::move(captured_func), &ctx->eigen_cpu_device());
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapAndBatchIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(),
+ std::move(done), prefix);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
+ *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls,
+ drop_remainder, output_types_, output_shapes_,
+ std::move(captured_func), &ctx->eigen_cpu_device(),
+ std::move(map_func));
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
- const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
- const Eigen::ThreadPoolDevice* device)
+ const Eigen::ThreadPoolDevice* device,
+ MapAndBatchIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ func_(func),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
output_types_(output_types),
output_shapes_(output_shapes),
- map_fn_(func),
captured_func_(std::move(captured_func)),
- device_(device) {
+ device_(device),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -123,8 +171,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")},
+ map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -143,7 +192,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -165,7 +214,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
- b->BuildAttrValue(map_fn_, &f);
+ b->BuildAttrValue(func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
@@ -185,12 +234,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
+ explicit Iterator(const Params& params,
+ MapAndBatchIteratorFunction map_func)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
- params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
+ map_func_(std::move(map_func)) {}
~Iterator() override {
mutex_lock l(*mu_);
@@ -297,44 +348,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls; // access guarded by owner's mutex
};
- void Callback(const std::shared_ptr<IteratorContext>& ctx,
- const std::shared_ptr<BatchResult>& result,
- const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
- result->UpdateStatus(status);
- if (status.ok()) {
- EnsureOutputAllocated(ctx, result, return_values);
- for (size_t i = 0; i < return_values->size(); ++i) {
- const Tensor& tensor = return_values->at(i);
- Tensor* batch = &(result->output)[i];
- if (tensor.NumElements() !=
- (batch->NumElements() / batch->dim_size(0))) {
- TensorShape batch_shape = batch->shape();
- batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does not "
- "match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
- break;
- }
- // TODO(mrry): Add a version of DoParallelConcat that allows us to
- // move `tensor` where possible, to speed up string tensor batching.
- Status copy_status = ::tensorflow::functor::DoParallelConcat(
- *dataset()->device_, tensor, offset, batch);
- if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
- break;
- }
- }
- {
- mutex_lock l(result->mu);
- result->num_elements++;
- }
- }
- CallCompleted(result);
- }
-
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
@@ -363,21 +376,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
- // Call `captured_func_(input_element)`, using `Callback` to store the
- // result in `result`.
- (*ctx->runner())(std::bind(
- [this, result, offset](std::shared_ptr<IteratorContext> ctx,
- std::vector<Tensor> input_element) {
- std::shared_ptr<std::vector<Tensor>> return_values(
- new std::vector<Tensor>());
- dataset()->captured_func_->RunAsync(
- ctx.get(), std::move(input_element), return_values.get(),
- [this, ctx, result, return_values, offset](Status status) {
- Callback(ctx, result, return_values, offset, status);
- },
- prefix());
- },
- ctx, std::move(input_element)));
+ std::shared_ptr<std::vector<Tensor>> return_values =
+ std::make_shared<std::vector<Tensor>>();
+ auto done = [this, ctx, result, return_values, offset](Status status) {
+ result->UpdateStatus(status);
+ if (status.ok()) {
+ EnsureOutputAllocated(ctx, result, return_values);
+ for (size_t i = 0; i < return_values->size(); ++i) {
+ const Tensor& tensor = return_values->at(i);
+ Tensor* batch = &(result->output)[i];
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ TensorShape batch_shape = batch->shape();
+ batch_shape.RemoveDim(0);
+ result->UpdateStatus(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does "
+ "not match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()));
+ break;
+ }
+ // TODO(mrry): Add a version of DoParallelConcat that allows us to
+ // move `tensor` where possible, to speed up string tensor
+ // batching.
+ Status copy_status = ::tensorflow::functor::DoParallelConcat(
+ *dataset()->device_, tensor, offset, batch);
+ if (!copy_status.ok()) {
+ result->UpdateStatus(copy_status);
+ break;
+ }
+ }
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
+ }
+ CallCompleted(result);
+ };
+
+ // Apply the map function on `input_element`, storing the result in
+ // `return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ std::move(return_values), std::move(done));
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
@@ -404,10 +444,11 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
- runner_thread_.reset(ctx->env()->StartThread(
- {}, "runner_thread",
- std::bind(&Iterator::RunnerThread, this, ctx_copy)));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ runner_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+ runner_thread_->Schedule(
+ std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}
@@ -509,8 +550,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
- batch_results_.emplace_back(
- new BatchResult(dataset()->batch_size_));
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
}
int64 offset = call_counter_++ % dataset()->batch_size_;
new_calls.emplace_back(batch_results_.back(), offset);
@@ -527,7 +568,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
- batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -653,6 +695,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
+ const MapAndBatchIteratorFunction map_func_;
+
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
@@ -660,7 +704,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
};
@@ -671,9 +715,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const bool drop_remainder_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
- const NameAttrList map_fn_;
const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned
+ const MapAndBatchIteratorFunction map_func_;
};
const int op_version_;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index f112e1dc43..6b6ffabf4f 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -28,6 +30,9 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapIteratorFunction = std::function<Status(
+ IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;
+
explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ return raw_captured_func->Run(ctx, std::move(args), out_tensors);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ return Status::OK();
+ };
+ }
+
*output = new Dataset(ctx, input, func_, std::move(captured_func),
- output_types_, output_shapes_);
+ output_types_, output_shapes_, std::move(map_func));
}
private:
@@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
+ const std::vector<PartialTensorShape>& output_shapes,
+ MapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
output_types_(output_types),
- output_shapes_(output_shapes) {
+ output_shapes_(output_shapes),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Map")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ explicit Iterator(const Params& params, MapIteratorFunction map_func)
+ : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- Status s =
- dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
+ Status s = map_func_(ctx, args, out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
@@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
std::unique_ptr<IteratorBase> input_impl_;
+ const MapIteratorFunction map_func_;
};
const DatasetBase* const input_;
@@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const MapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index 9aa505f4f1..859df57962 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!optimize_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- optimize_thread_.reset(ctx->env()->StartThread(
- {}, "optimize_thread",
- [this, new_ctx]() { OptimizeThread(new_ctx); }));
+ optimize_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "optimize_thread");
+ optimize_thread_->Schedule(
+ [this, new_ctx]() { OptimizeThread(new_ctx); });
}
return Status::OK();
}
@@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
mutex mu_;
condition_variable cond_var_;
std::shared_ptr<model::Model> model_;
- std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<BackgroundWorker> optimize_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 6b6b3d6ab9..9c836b836e 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- worker_threads_.emplace_back(ctx->env()->StartThread(
- {}, "worker_thread",
- [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+ worker_threads_.emplace_back(
+ MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
+ worker_threads_.back()->Schedule(
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
}
}
return Status::OK();
@@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
workers_[i].SetInputs(s, std::move(args));
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- worker_threads_.emplace_back(ctx->env()->StartThread(
- {}, "worker_thread",
- [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
+ worker_threads_.emplace_back(
+ MakeUnique<BackgroundWorker>(ctx->env(), "worker_thread"));
+ worker_threads_.back()->Schedule(
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); });
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads.
- std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<BackgroundWorker>> worker_threads_
+ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
@@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- runner_thread_.reset(ctx->env()->StartThread(
- {}, "runner_thread",
- [this, new_ctx]() { RunnerThread(new_ctx); }));
+ runner_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+ runner_thread_->Schedule(
+ [this, new_ctx]() { RunnerThread(new_ctx); });
}
}
@@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled.
bool cancelled_ GUARDED_BY(*mu_) = false;
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 6abe6c8338..3a14924fba 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/random/random.h"
@@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ ParallelMapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors,
+ std::move(done), prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args),
+ out_tensors, std::move(done)));
+ };
+ }
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args, std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
- std::move(captured_func));
+ std::move(captured_func), std::move(map_func));
}
private:
@@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction> captured_func)
+ std::unique_ptr<CapturedFunction> captured_func,
+ ParallelMapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
@@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes),
use_inter_op_parallelism_(use_inter_op_parallelism),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
- ParallelMapIteratorFunction map_func =
- [this, new_prefix](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done), new_prefix);
- };
- if (!use_inter_op_parallelism_) {
- map_func = [map_func](
- IteratorContext* ctx, std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
- result, std::move(done)));
- };
- }
-
- return NewParallelMapIterator({this, new_prefix}, input_,
- std::move(init_func), std::move(map_func),
- num_parallel_calls_);
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
+ std::move(init_func), map_func_, num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
+ const ParallelMapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 13bd4b6036..e69274e4f2 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -179,10 +180,11 @@ class ParallelMapIterator : public DatasetBaseIterator {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
- runner_thread_.reset(ctx->env()->StartThread(
- {}, "runner_thread",
- std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
+ runner_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "runner_thread");
+ runner_thread_->Schedule(
+ std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
}
}
@@ -208,15 +210,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in `result->return_values`,
- // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
- map_func_(ctx.get(), std::move(input_element), &result->return_values,
- std::move(done));
+ // Apply the map function on `input_element`, storing the result in
+ // `result->return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ &result->return_values, std::move(done));
}
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
@@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
GUARDED_BY(*mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ std::unique_ptr<BackgroundWorker> runner_thread_ GUARDED_BY(*mu_);
bool cancelled_ GUARDED_BY(*mu_) = false;
};
@@ -349,9 +351,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(
- new ParallelMapIterator(params, input_dataset, std::move(init_func),
- std::move(map_func), num_parallel_calls));
+ return MakeUnique<ParallelMapIterator>(
+ params, input_dataset, std::move(init_func), std::move(map_func),
+ num_parallel_calls);
}
} // namespace data
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index dc26c5cf25..813f13c9e4 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -30,7 +30,7 @@ namespace data {
// 3. A `std::vector<Tensor>*` to which the function will write the result.
// 4. A `StatusCallback` that should be invoked when the function is complete.
using ParallelMapIteratorFunction =
- std::function<void(IteratorContext*, std::vector<Tensor>,
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 1d1a717062..7de5ea8860 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- auto map_fn = [this](IteratorContext* ctx,
+ auto map_fn = [this](IteratorContext* ctx, const string& prefix,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
(*ctx->runner())([this, ctx, input_element, result, done]() {
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 754ed772db..e9c38eb8a0 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
+ prefetch_thread_ =
+ MakeUnique<BackgroundWorker>(ctx->env(), "prefetch_thread");
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
- prefetch_thread_.reset(ctx->env()->StartThread(
- {}, "prefetch_thread",
- [this, new_ctx]() { PrefetchThread(new_ctx); }));
+ prefetch_thread_->Schedule(
+ [this, new_ctx]() { PrefetchThread(new_ctx); });
}
return Status::OK();
}
@@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<BackgroundWorker> prefetch_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
bool prefetch_thread_finished_ GUARDED_BY(mu_) = false;
};
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 66466d6a36..9f54c381a9 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -485,7 +485,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
int64 buffer_size, int64 seed, int64 seed2, int64 count)
: ShuffleDatasetBase(ctx, input, buffer_size, count),
seed_(seed),
- seed2_(seed) {}
+ seed2_(seed2) {}
string DebugString() const override {
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 3f76695bb1..7bb2077b62 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel {
public:
explicit ToTFRecordOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
- thread_pool_(new thread::ThreadPool(
- ctx->env(), ThreadOptions(),
- strings::StrCat("to_tf_record__op_", SanitizeThreadSuffix(name())),
- 1 /* num_threads */, false /* low_latency_hint */)) {}
+ background_worker_(
+ ctx->env(),
+ strings::StrCat("to_tf_record_op_", SanitizeThreadSuffix(name()))) {
+ }
template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
- thread_pool_->Schedule([this, ctx, done]() {
+ background_worker_.Schedule([this, ctx, done]() {
string filename;
OP_REQUIRES_OK_ASYNC(
ctx, ParseScalarArgument<string>(ctx, "filename", &filename), done);
@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel {
}
private:
- std::unique_ptr<thread::ThreadPool> thread_pool_;
+ BackgroundWorker background_worker_;
};
REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 04a53697c0..3810d817ca 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel {
Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
RandomGammaOp<TYPE>)
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<CPUDevice, IntType>);
TF_CALL_half(REGISTER);
@@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<int32>("T") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<GPUDevice, IntType>);
TF_CALL_half(REGISTER);
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index 173fea37ed..e67695d54a 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
-#define REGISTER_RELU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluGradOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6Op<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6GradOp<CPUDevice, type>)
+#define REGISTER_RELU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6Op<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<CPUDevice, type>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
@@ -99,6 +105,19 @@ namespace functor {
extern template struct Relu6Grad<GPUDevice, T>; \
\
template <> \
+ void LeakyRelu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct LeakyRelu<GPUDevice, T>; \
+ \
+ template <> \
+ void LeakyReluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct LeakyReluGrad<GPUDevice, T>; \
+ \
+ template <> \
void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
@@ -134,30 +153,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
// Registration of the GPU implementations.
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6Op<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6GradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- SeluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6Op<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SeluGradOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
@@ -188,30 +213,36 @@ REGISTER_KERNEL_BUILDER(
#ifdef TENSORFLOW_USE_SYCL
// Registration of the GPU implementations.
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6Op<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6GradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- SeluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6Op<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6GradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SeluGradOp<SYCLDevice, type>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 4775deeb61..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
}
template <typename Device, typename T>
+class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
+ public:
+ explicit LeakyReluOp(OpKernelConstruction* context)
+ : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::LeakyRelu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+ output->flat<T>());
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+ : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+ explicit LeakyReluGradOp(OpKernelConstruction* context)
+ : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, T alpha, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+ // outputs (using either one yields the same result here).
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, alpha_, output);
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a, T alpha,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::LeakyReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+ output->flat<T>());
+};
+
+template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index e564da335a..f917142a12 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -91,6 +91,36 @@ struct Relu6Grad {
}
};
+// Functor used by LeakyReluOp to do the computations.
+template <typename Device, typename T>
+struct LeakyRelu {
+ // Computes LeakyRelu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ T alpha, typename TTypes<T>::Tensor activations) {
+ activations.device(d) = features.cwiseMax(features * alpha);
+ }
+};
+
+// Functor used by LeakyReluGradOp to do the computations.
+template <typename Device, typename T>
+struct LeakyReluGrad {
+ // Computes LeakyReluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the LeakyRelu op.
+ // features: either the inputs that were passed to the LeakyRelu or, or its
+ // outputs (using either one yields the same result here).
+ // backprops: gradients to backpropagate to the LeakyRelu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features, T alpha,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ (features > static_cast<T>(0)).select(gradients, gradients * alpha);
+ }
+};
+
// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index b9391517c1..dd5f9495e2 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -145,14 +145,16 @@ struct Relu<Device, qint8> {
} // namespace functor
// Definition of the GPU implementations declared in relu_op.cc.
-#define DEFINE_GPU_KERNELS(T) \
- template struct functor::Relu<GPUDevice, T>; \
- template struct functor::ReluGrad<GPUDevice, T>; \
- template struct functor::Relu6<GPUDevice, T>; \
- template struct functor::Relu6Grad<GPUDevice, T>; \
- template struct functor::Elu<GPUDevice, T>; \
- template struct functor::EluGrad<GPUDevice, T>; \
- template struct functor::Selu<GPUDevice, T>; \
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Relu<GPUDevice, T>; \
+ template struct functor::ReluGrad<GPUDevice, T>; \
+ template struct functor::Relu6<GPUDevice, T>; \
+ template struct functor::Relu6Grad<GPUDevice, T>; \
+ template struct functor::LeakyRelu<GPUDevice, T>; \
+ template struct functor::LeakyReluGrad<GPUDevice, T>; \
+ template struct functor::Elu<GPUDevice, T>; \
+ template struct functor::EluGrad<GPUDevice, T>; \
+ template struct functor::Selu<GPUDevice, T>; \
template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index eab176c7fb..925f5291a6 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- CPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+template <typename Device, typename IntType>
+class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
+ public:
+ using StatelessRandomOpBase::StatelessRandomOpBase;
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+ void Fill(OpKernelContext* context, random::PhiloxRandom random,
+ Tensor* output) override {
+ const Tensor& minval = context->input(2);
+ const Tensor& maxval = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
+ errors::InvalidArgument("minval must be 0-D, got shape ",
+ minval.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
+ errors::InvalidArgument("maxval must be 0-D, got shape ",
+ maxval.shape().DebugString()));
+
+ // Verify that minval < maxval. Note that we'll never reach this point for
+ // empty output. Zero impossible things are fine.
+ const auto lo = minval.scalar<IntType>()();
+ const auto hi = maxval.scalar<IntType>()();
+ OP_REQUIRES(
+ context, lo < hi,
+ errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+ // Build distribution
+ typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+ Distribution;
+ Distribution dist(lo, hi);
+
+ auto flat = output->flat<IntType>();
+ // Reuse the compute kernels from the stateful random ops
+ functor::FillPhiloxRandom<Device, Distribution>()(
+ context, context->eigen_device<Device>(), random, flat.data(),
+ flat.size(), dist);
+ }
+};
-#undef REGISTER
+#define REGISTER(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomUniform") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessTruncatedNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp< \
+ DEVICE##Device, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+
+#define REGISTER_INT(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
+#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
+#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
+#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
+
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_int32(REGISTER_INT_CPU);
+TF_CALL_int64(REGISTER_INT_CPU);
#if GOOGLE_CUDA
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- GPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_int32(REGISTER_INT_GPU);
+TF_CALL_int64(REGISTER_INT_GPU);
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+#endif // GOOGLE_CUDA
#undef REGISTER
-
-#endif // GOOGLE_CUDA
+#undef REGISTER_INT
+#undef REGISTER_CPU
+#undef REGISTER_GPU
+#undef REGISTER_INT_CPU
+#undef REGISTER_INT_GPU
} // namespace
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 3559baa18e..3bdcfc90b8 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -108,7 +108,7 @@ class UniqueOp : public OpKernel {
std::unordered_map<T, TIndex> uniq;
uniq.reserve(2 * N);
- for (int64 i = 0, j = 0; i < N; ++i) {
+ for (Eigen::Index i = 0, j = 0; i < N; ++i) {
auto it = uniq.insert(std::make_pair(Tin(i), j));
idx_vec(i) = it.first->second;
if (it.second) {
@@ -131,19 +131,20 @@ class UniqueOp : public OpKernel {
// General implementation when unique is run over multiple elements.
auto Tin = input.shaped<T, 3>(new_sizes);
- auto hash_fn = [&Tin](const int64& key) {
+ auto hash_fn = [&Tin](const Eigen::Index& key) {
size_t h = 0;
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
}
}
return h;
};
- auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ auto equal_to_fn = [&Tin](const Eigen::Index& lhs,
+ const Eigen::Index& rhs) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
return false;
}
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 780c6f6448..9df0ece69b 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -28981,6 +28981,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"
@@ -70897,6 +70965,62 @@ op {
}
}
op {
+ name: "StatelessRandomNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessRandomUniform"
input_arg {
name: "shape"
@@ -70994,6 +71118,118 @@ op {
}
}
op {
+ name: "StatelessRandomUniform"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -71091,6 +71327,62 @@ op {
}
}
op {
+ name: "StatelessTruncatedNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessWhile"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3eff728f03..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount")
.Attr("T: {int32, int64, float32, float64}")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(1));
+ ShapeHandle unused;
+ // The input `size` must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ const Tensor* size_tensor = c->input_tensor(1);
+ if (size_tensor == nullptr) {
+ // Return unknown shape if size is not known.
+ c->set_output(0, c->UnknownShapeOfRank(1));
+ return Status::OK();
+ }
+
+ // Return `[size]` shape if size is known.
+ int32 size_val = size_tensor->scalar<int32>()();
+ if (size_val < 0) {
+ return errors::InvalidArgument("size (", size_val,
+ ") must be non-negative");
+ }
+ c->set_output(0, c->MakeShape({size_val}));
return Status::OK();
});
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index be4c3ed2b6..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
+
+TEST(MathOpsTest, Bincount_ShapeFn) {
+ ShapeInferenceTestOp op("Bincount");
+
+ // size should be scalar.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
+
+ INFER_OK(op, "?;?;?", "[?]");
+ INFER_OK(op, "?;[];?", "[?]");
+ INFER_OK(op, "[?];[];?", "[?]");
+ INFER_OK(op, "[?];[];[?]", "[?]");
+}
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d1d81b27cc..a9ca69ad86 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+REGISTER_OP("LeakyRelu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("LeakyReluGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 0d8997c1bd..2048ad26ac 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14296,6 +14296,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"
@@ -32978,6 +33046,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33033,6 +33102,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33066,6 +33136,62 @@ op {
}
}
op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -33088,6 +33214,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index adc9cd1486..65bdde375b 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp")
Status VariableShapeShapeFn(InferenceContext* c) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->empty()) {
- return errors::InvalidArgument("Handle doesn't have shape information.");
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
}
ShapeHandle var_shape = (*handle_data)[0].shape;
int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 742709fb18..f919a21d60 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -19,42 +19,55 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-static Status StatelessShape(shape_inference::InferenceContext* context) {
+static Status StatelessShape(InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
- TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
DimensionHandle unused;
- TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shape
ShapeHandle out;
- TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
- context->set_output(0, out);
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
return Status::OK();
}
-#define REGISTER_STATELESS_OP(name) \
- REGISTER_OP(name) \
- .Input("shape: T") \
- .Input("seed: Tseed") \
- .Output("output: dtype") \
- .Attr("dtype: {half,float,double} = DT_FLOAT") \
- .Attr("T: {int32, int64} = DT_INT32") \
- .Attr("Tseed: {int32, int64} = DT_INT64") \
+#define REGISTER_STATELESS_OP(name) \
+ REGISTER_OP(name) \
+ .Input("shape: T") \
+ .Input("seed: Tseed") \
+ .Output("output: dtype") \
+ .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
+ .Attr("T: {int32, int64} = DT_INT32") \
+ .Attr("Tseed: {int32, int64} = DT_INT64") \
.SetShapeFn(StatelessShape)
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomUniform");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomNormal");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessTruncatedNormal");
-// This op is exposed through contrib/stateless only. The interface may change.
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+ .Input("shape: T")
+ .Input("seed: Tseed")
+ .Input("minval: dtype")
+ .Input("maxval: dtype")
+ .Output("output: dtype")
+ .Attr("dtype: {int32, int64}")
+ .Attr("T: {int32, int64}")
+ .Attr("Tseed: {int32, int64} = DT_INT64")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return StatelessShape(c);
+ });
+
REGISTER_OP("StatelessMultinomial")
.Input("logits: T")
.Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
return Status::OK();
});
-#undef REGISTER_STATELESS_OP
-
} // namespace tensorflow
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 8c31468ff5..7ccd54b818 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -83,6 +83,10 @@ message RewriterConfig {
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
NumIterationsType meta_optimizer_iterations = 12;
+ // Maximum number of milliseconds to spend optimizing a single graph before
+ // timing out. If equal to 0 the system picks a default (currently 5 minutes).
+ // If less than 0 the optimizer will never time out.
+ int64 meta_optimizer_timeout_ms = 20;
// The minimum number of nodes in a graph to optimizer. For smaller graphs,
// optimization is skipped.
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index a7bbb80c82..fe99915a6c 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -7221,6 +7221,45 @@ func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.
return components
}
+// Deprecated. Use TensorArrayGradV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3
+func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorArrayWriteV2",
+ Input: []tf.Input{
+ handle, index, value, flow_in,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Writes the given dataset to the given file using the TFRecord format.
+//
+// Arguments:
+// input_dataset: A variant tensor representing the dataset to write.
+// filename: A scalar string tensor representing the filename to use.
+// compression_type: A scalar string tensor containing either (i) the empty string (no
+// compression), (ii) "ZLIB", or (iii) "GZIP".
+//
+// Returns the created operation.
+func DatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DatasetToTFRecord",
+ Input: []tf.Input{
+ input_dataset, filename, compression_type,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// Computes rectified linear 6: `min(max(features, 0), 6)`.
func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
if scope.Err() != nil {
@@ -8251,44 +8290,6 @@ func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAt
return op.Output(0)
}
-// Bucketizes 'input' based on 'boundaries'.
-//
-// For example, if the inputs are
-// boundaries = [0, 10, 100]
-// input = [[-5, 10000]
-// [150, 10]
-// [5, 100]]
-//
-// then the output will be
-// output = [[0, 3]
-// [3, 2]
-// [1, 3]]
-//
-// Arguments:
-// input: Any shape of Tensor contains with int or float type.
-// boundaries: A sorted list of floats gives the boundary of the buckets.
-//
-// Returns Same shape with 'input', each value of input replaced with bucket index.
-//
-// @compatibility(numpy)
-// Equivalent to np.digitize.
-// @end_compatibility
-func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"boundaries": boundaries}
- opspec := tf.OpSpec{
- Type: "Bucketize",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2.
type FusedBatchNormV2Attr func(optionalAttr)
@@ -9640,36 +9641,6 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...
return op.Output(0)
}
-// Returns the element-wise sum of a list of tensors.
-//
-// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
-// wait for all of its inputs to be ready before beginning to sum. This can
-// save memory if inputs are ready at different times, since minimum temporary
-// storage is proportional to the output size rather than the inputs size.
-//
-// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
-//
-// Returns a `Tensor` of same shape and type as the elements of `inputs`.
-//
-// Arguments:
-// inputs: A list of `Tensor` objects, each with same shape and type.
-// shape: Shape of elements of `inputs`.
-func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"shape": shape}
- opspec := tf.OpSpec{
- Type: "AccumulateNV2",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// RandomShuffleAttr is an optional argument to RandomShuffle.
type RandomShuffleAttr func(optionalAttr)
@@ -10383,206 +10354,65 @@ func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.
return scope.AddOperation(opspec)
}
-// Encode audio data using the WAV file format.
-//
-// This operation will generate a string suitable to be saved out to create a .wav
-// audio file. It will be encoded in the 16-bit PCM format. It takes in float
-// values in the range -1.0f to 1.0f, and any outside that value will be clamped to
-// that range.
-//
-// `audio` is a 2-D float Tensor of shape `[length, channels]`.
-// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100).
-//
-// Arguments:
-// audio: 2-D with shape `[length, channels]`.
-// sample_rate: Scalar containing the sample frequency.
-//
-// Returns 0-D. WAV-encoded file contents.
-func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "EncodeWav",
- Input: []tf.Input{
- audio, sample_rate,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes atan of x element-wise.
-func Atan(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Atan",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax.
-type ResourceApplyAdaMaxAttr func(optionalAttr)
-
-// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, m, and v tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the AdaMax algorithm.
+// Locks a mutex resource. The output is the lock. So long as the lock tensor
//
-// m_t <- beta1 * m_{t-1} + (1 - beta1) * g
-// v_t <- max(beta2 * v_{t-1}, abs(g))
-// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
+// is alive, any other request to use `MutexLock` with this mutex will wait.
//
-// Arguments:
-// var_: Should be from a Variable().
-// m: Should be from a Variable().
-// v: Should be from a Variable().
-// beta1_power: Must be a scalar.
-// lr: Scaling factor. Must be a scalar.
-// beta1: Momentum factor. Must be a scalar.
-// beta2: Momentum factor. Must be a scalar.
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
+// This is particularly useful for creating a critical section when used in
+// conjunction with `MutexLockIdentity`:
//
-// Returns the created operation.
-func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyAdaMax",
- Input: []tf.Input{
- var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// AssertAttr is an optional argument to Assert.
-type AssertAttr func(optionalAttr)
-
-// AssertSummarize sets the optional summarize attribute to value.
+// ```python
//
-// value: Print this many entries of each tensor.
-// If not specified, defaults to 3
-func AssertSummarize(value int64) AssertAttr {
- return func(m optionalAttr) {
- m["summarize"] = value
- }
-}
-
-// Asserts that the given condition is true.
+// mutex = mutex_v2(
+// shared_name=handle_name, container=container, name=name)
//
-// If `condition` evaluates to false, print the list of tensors in `data`.
-// `summarize` determines how many entries of the tensors to print.
+// def execute_in_critical_section(fn, *args, **kwargs):
+// lock = gen_resource_variable_ops.mutex_lock(mutex)
//
-// Arguments:
-// condition: The condition to evaluate.
-// data: The tensors to print out when condition is false.
+// with ops.control_dependencies([lock]):
+// r = fn(*args, **kwargs)
//
-// Returns the created operation.
-func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Assert",
- Input: []tf.Input{
- condition, tf.OutputList(data),
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// Split a `SparseTensor` into `num_split` tensors along one dimension.
+// with ops.control_dependencies(nest.flatten(r)):
+// with ops.colocate_with(mutex):
+// ensure_lock_exists = mutex_lock_identity(lock)
//
-// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
-// `[0 : shape[split_dim] % num_split]` gets one extra dimension.
-// For example, if `split_dim = 1` and `num_split = 2` and the input is
+// # Make sure that if any element of r is accessed, all of
+// # them are executed together.
+// r = nest.map_structure(tf.identity, r)
//
-// input_tensor = shape = [2, 7]
-// [ a d e ]
-// [b c ]
+// with ops.control_dependencies([ensure_lock_exists]):
+// return nest.map_structure(tf.identity, r)
+// ```
//
-// Graphically the output tensors are:
+// While `fn` is running in the critical section, no other functions which wish to
+// use this critical section may run.
//
-// output_tensor[0] = shape = [2, 4]
-// [ a ]
-// [b c ]
+// Often the use case is that two executions of the same graph, in parallel,
+// wish to run `fn`; and we wish to ensure that only one of them executes
+// at a time. This is especially important if `fn` modifies one or more
+// variables at a time.
//
-// output_tensor[1] = shape = [2, 3]
-// [ d e ]
-// [ ]
+// It is also useful if two separate functions must share a resource, but we
+// wish to ensure the usage is exclusive.
//
// Arguments:
-// split_dim: 0-D. The dimension along which to split. Must be in the range
-// `[0, rank(shape))`.
-// indices: 2-D tensor represents the indices of the sparse tensor.
-// values: 1-D tensor represents the values of the sparse tensor.
-// shape: 1-D. tensor represents the shape of the sparse tensor.
-// output indices: A list of 1-D tensors represents the indices of the output
-// sparse tensors.
-// num_split: The number of ways to split.
+// mutex: The mutex resource to lock.
//
-// Returns A list of 1-D tensors represents the values of the output sparse
-// tensors.A list of 1-D tensors represents the shape of the output sparse
-// tensors.
-func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) {
+// Returns A tensor that keeps a shared pointer to a lock on the mutex;
+// when the Tensor is destroyed, the use count on the shared pointer is decreased
+// by 1. When it reaches 0, the lock is released.
+func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"num_split": num_split}
opspec := tf.OpSpec{
- Type: "SparseSplit",
+ Type: "MutexLock",
Input: []tf.Input{
- split_dim, indices, values, shape,
+ mutex,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil {
- scope.UpdateErr("SparseSplit", err)
- return
- }
- if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil {
- scope.UpdateErr("SparseSplit", err)
- return
- }
- if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil {
- scope.UpdateErr("SparseSplit", err)
- return
- }
- return output_indices, output_values, output_shape
+ return op.Output(0)
}
// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2.
@@ -11151,6 +10981,44 @@ func Tan(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Bucketizes 'input' based on 'boundaries'.
+//
+// For example, if the inputs are
+// boundaries = [0, 10, 100]
+// input = [[-5, 10000]
+// [150, 10]
+// [5, 100]]
+//
+// then the output will be
+// output = [[0, 3]
+// [3, 2]
+// [1, 3]]
+//
+// Arguments:
+// input: Any shape of Tensor contains with int or float type.
+// boundaries: A sorted list of floats gives the boundary of the buckets.
+//
+// Returns Same shape with 'input', each value of input replaced with bucket index.
+//
+// @compatibility(numpy)
+// Equivalent to np.digitize.
+// @end_compatibility
+func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"boundaries": boundaries}
+ opspec := tf.OpSpec{
+ Type: "Bucketize",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// EncodeJpegAttr is an optional argument to EncodeJpeg.
type EncodeJpegAttr func(optionalAttr)
@@ -11699,6 +11567,238 @@ func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output,
return op.Output(0)
}
+// Encode audio data using the WAV file format.
+//
+// This operation will generate a string suitable to be saved out to create a .wav
+// audio file. It will be encoded in the 16-bit PCM format. It takes in float
+// values in the range -1.0f to 1.0f, and any outside that value will be clamped to
+// that range.
+//
+// `audio` is a 2-D float Tensor of shape `[length, channels]`.
+// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100).
+//
+// Arguments:
+// audio: 2-D with shape `[length, channels]`.
+// sample_rate: Scalar containing the sample frequency.
+//
+// Returns 0-D. WAV-encoded file contents.
+func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeWav",
+ Input: []tf.Input{
+ audio, sample_rate,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes atan of x element-wise.
+func Atan(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Atan",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax.
+type ResourceApplyAdaMaxAttr func(optionalAttr)
+
+// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, m, and v tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the AdaMax algorithm.
+//
+// m_t <- beta1 * m_{t-1} + (1 - beta1) * g
+// v_t <- max(beta2 * v_{t-1}, abs(g))
+// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
+//
+// Arguments:
+// var_: Should be from a Variable().
+// m: Should be from a Variable().
+// v: Should be from a Variable().
+// beta1_power: Must be a scalar.
+// lr: Scaling factor. Must be a scalar.
+// beta1: Momentum factor. Must be a scalar.
+// beta2: Momentum factor. Must be a scalar.
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+//
+// Returns the created operation.
+func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyAdaMax",
+ Input: []tf.Input{
+ var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// AssertAttr is an optional argument to Assert.
+type AssertAttr func(optionalAttr)
+
+// AssertSummarize sets the optional summarize attribute to value.
+//
+// value: Print this many entries of each tensor.
+// If not specified, defaults to 3
+func AssertSummarize(value int64) AssertAttr {
+ return func(m optionalAttr) {
+ m["summarize"] = value
+ }
+}
+
+// Asserts that the given condition is true.
+//
+// If `condition` evaluates to false, print the list of tensors in `data`.
+// `summarize` determines how many entries of the tensors to print.
+//
+// Arguments:
+// condition: The condition to evaluate.
+// data: The tensors to print out when condition is false.
+//
+// Returns the created operation.
+func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Assert",
+ Input: []tf.Input{
+ condition, tf.OutputList(data),
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Split a `SparseTensor` into `num_split` tensors along one dimension.
+//
+// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
+// `[0 : shape[split_dim] % num_split]` gets one extra dimension.
+// For example, if `split_dim = 1` and `num_split = 2` and the input is
+//
+// input_tensor = shape = [2, 7]
+// [ a d e ]
+// [b c ]
+//
+// Graphically the output tensors are:
+//
+// output_tensor[0] = shape = [2, 4]
+// [ a ]
+// [b c ]
+//
+// output_tensor[1] = shape = [2, 3]
+// [ d e ]
+// [ ]
+//
+// Arguments:
+// split_dim: 0-D. The dimension along which to split. Must be in the range
+// `[0, rank(shape))`.
+// indices: 2-D tensor represents the indices of the sparse tensor.
+// values: 1-D tensor represents the values of the sparse tensor.
+// shape: 1-D. tensor represents the shape of the sparse tensor.
+// output indices: A list of 1-D tensors represents the indices of the output
+// sparse tensors.
+// num_split: The number of ways to split.
+//
+// Returns A list of 1-D tensors represents the values of the output sparse
+// tensors.A list of 1-D tensors represents the shape of the output sparse
+// tensors.
+func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_split": num_split}
+ opspec := tf.OpSpec{
+ Type: "SparseSplit",
+ Input: []tf.Input{
+ split_dim, indices, values, shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil {
+ scope.UpdateErr("SparseSplit", err)
+ return
+ }
+ if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil {
+ scope.UpdateErr("SparseSplit", err)
+ return
+ }
+ if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil {
+ scope.UpdateErr("SparseSplit", err)
+ return
+ }
+ return output_indices, output_values, output_shape
+}
+
+// Returns the element-wise sum of a list of tensors.
+//
+// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
+// wait for all of its inputs to be ready before beginning to sum. This can
+// save memory if inputs are ready at different times, since minimum temporary
+// storage is proportional to the output size rather than the inputs size.
+//
+// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
+//
+// Returns a `Tensor` of same shape and type as the elements of `inputs`.
+//
+// Arguments:
+// inputs: A list of `Tensor` objects, each with same shape and type.
+// shape: Shape of elements of `inputs`.
+func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "AccumulateNV2",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal.
type StatelessTruncatedNormalAttr func(optionalAttr)
@@ -13925,67 +14025,6 @@ func CudnnRNNBackpropV2(scope *Scope, input tf.Output, input_h tf.Output, input_
return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
}
-// Locks a mutex resource. The output is the lock. So long as the lock tensor
-//
-// is alive, any other request to use `MutexLock` with this mutex will wait.
-//
-// This is particularly useful for creating a critical section when used in
-// conjunction with `MutexLockIdentity`:
-//
-// ```python
-//
-// mutex = mutex_v2(
-// shared_name=handle_name, container=container, name=name)
-//
-// def execute_in_critical_section(fn, *args, **kwargs):
-// lock = gen_resource_variable_ops.mutex_lock(mutex)
-//
-// with ops.control_dependencies([lock]):
-// r = fn(*args, **kwargs)
-//
-// with ops.control_dependencies(nest.flatten(r)):
-// with ops.colocate_with(mutex):
-// ensure_lock_exists = mutex_lock_identity(lock)
-//
-// # Make sure that if any element of r is accessed, all of
-// # them are executed together.
-// r = nest.map_structure(tf.identity, r)
-//
-// with ops.control_dependencies([ensure_lock_exists]):
-// return nest.map_structure(tf.identity, r)
-// ```
-//
-// While `fn` is running in the critical section, no other functions which wish to
-// use this critical section may run.
-//
-// Often the use case is that two executions of the same graph, in parallel,
-// wish to run `fn`; and we wish to ensure that only one of them executes
-// at a time. This is especially important if `fn` modifies one or more
-// variables at a time.
-//
-// It is also useful if two separate functions must share a resource, but we
-// wish to ensure the usage is exclusive.
-//
-// Arguments:
-// mutex: The mutex resource to lock.
-//
-// Returns A tensor that keeps a shared pointer to a lock on the mutex;
-// when the Tensor is destroyed, the use count on the shared pointer is decreased
-// by 1. When it reaches 0, the lock is released.
-func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MutexLock",
- Input: []tf.Input{
- mutex,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// StringFormatAttr is an optional argument to StringFormat.
type StringFormatAttr func(optionalAttr)
@@ -16807,26 +16846,6 @@ func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values
return op.Output(0), op.Output(1)
}
-// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
-//
-// The Hurwitz zeta function is defined as:
-//
-//
-// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\)
-func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Zeta",
- Input: []tf.Input{
- x, q,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns a list of tensors with the same shapes and contents as the input
//
// tensors.
@@ -18873,6 +18892,26 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D
return op.Output(0)
}
+// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
+//
+// The Hurwitz zeta function is defined as:
+//
+//
+// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\)
+func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Zeta",
+ Input: []tf.Input{
+ x, q,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Inverse fast Fourier transform.
//
// Computes the inverse 1-dimensional discrete Fourier transform over the
@@ -21413,43 +21452,6 @@ func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min
return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes the minimum along segments of a tensor.
-//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
-//
-// Computes a tensor such that
-// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such
-// that `segment_ids[j] == i`.
-//
-// If the min is empty for a given segment ID `i`, `output[i] = 0`.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMin.png" alt>
-// </div>
-//
-// Arguments:
-//
-// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
-// first dimension. Values should be sorted and can be repeated.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SegmentMin",
- Input: []tf.Input{
- data, segment_ids,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// SdcaOptimizerAttr is an optional argument to SdcaOptimizer.
type SdcaOptimizerAttr func(optionalAttr)
@@ -21924,6 +21926,43 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
+// Computes the minimum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// Computes a tensor such that
+// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such
+// that `segment_ids[j] == i`.
+//
+// If the min is empty for a given segment ID `i`, `output[i] = 0`.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMin.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+// first dimension. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SegmentMin",
+ Input: []tf.Input{
+ data, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the sum along segments of a tensor.
//
// Read
@@ -22757,6 +22796,21 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output
return op.Output(0)
}
+// Computes hyperbolic tangent of `x` element-wise.
+func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Tanh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the maximum along segments of a tensor.
//
// Read
@@ -22794,21 +22848,6 @@ func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.
return op.Output(0)
}
-// Computes hyperbolic tangent of `x` element-wise.
-func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Tanh",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a dataset that skips `count` elements from the `input_dataset`.
//
// Arguments:
@@ -29878,28 +29917,6 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) {
return op.Output(0)
}
-// Writes the given dataset to the given file using the TFRecord format.
-//
-// Arguments:
-// input_dataset: A variant tensor representing the dataset to write.
-// filename: A scalar string tensor representing the filename to use.
-// compression_type: A scalar string tensor containing either (i) the empty string (no
-// compression), (ii) "ZLIB", or (iii) "GZIP".
-//
-// Returns the created operation.
-func DatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DatasetToTFRecord",
- Input: []tf.Input{
- input_dataset, filename, compression_type,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// AvgPool3DAttr is an optional argument to AvgPool3D.
type AvgPool3DAttr func(optionalAttr)
@@ -31692,23 +31709,6 @@ func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size
return op.Output(0)
}
-// Deprecated. Use TensorArrayGradV3
-//
-// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3
-func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "TensorArrayWriteV2",
- Input: []tf.Input{
- handle, index, value, flow_in,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// SparseReduceMaxAttr is an optional argument to SparseReduceMax.
type SparseReduceMaxAttr func(optionalAttr)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index da3c56db92..822d596995 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5197,6 +5197,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "control_flow_ops_benchmark",
+ srcs = ["ops/control_flow_ops_benchmark.py"],
+ additional_deps = [
+ ":client_testlib",
+ ":constant_op",
+ ":control_flow_ops",
+ ":framework_ops",
+ "//tensorflow/python/eager:function",
+ ],
+ main = "ops/control_flow_ops_benchmark.py",
+)
+
+cuda_py_test(
name = "conv2d_benchmark",
size = "large",
srcs = ["ops/conv2d_benchmark.py"],
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index dc2d419d34..fcdbd0a82c 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -128,7 +128,13 @@ class TestCase(test.TestCase):
@contextlib.contextmanager
def converted(self, entity, converter_module, namespace, *tf_symbols):
node, ctx = self.prepare(entity, namespace)
- node = converter_module.transform(node, ctx)
+
+ if not isinstance(converter_module, (list, tuple)):
+ converter_module = (converter_module,)
+ for m in converter_module:
+ node = m.transform(node, ctx)
+ node = converter.standard_analysis(node, ctx, is_initial=True)
+
with self.compiled(node, namespace, *tf_symbols) as result:
yield result
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index 1416988ea3..29c406c248 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -67,6 +67,40 @@ def getnamespace(f):
return namespace
+def getqualifiedname(namespace, object_, max_depth=2):
+ """Returns the name by which a value can be referred to in a given namespace.
+
+ This function will recurse inside modules, but it will not search objects for
+ attributes. The recursion depth is controlled by max_depth.
+
+ Args:
+ namespace: Dict[str, Any], the namespace to search into.
+ object_: Any, the value to search.
+ max_depth: Optional[int], a limit to the recursion depth when searching
+ inside modules.
+ Returns: Union[str, None], the fully-qualified name that resolves to the value
+ o, or None if it couldn't be found.
+ """
+ for name, value in namespace.items():
+ # The value may be referenced by more than one symbol, case in which
+ # any symbol will be fine. If the program contains symbol aliases that
+ # change over time, this may capture a symbol that will later point to
+ # something else.
+ # TODO(mdan): Prefer the symbol that matches the value type name.
+ if object_ is value:
+ return name
+
+ # TODO(mdan): Use breadth-first search and avoid visiting modules twice.
+ if max_depth:
+ for name, value in namespace.items():
+ if tf_inspect.ismodule(value):
+ name_in_module = getqualifiedname(value.__dict__, object_,
+ max_depth - 1)
+ if name_in_module is not None:
+ return '{}.{}'.format(name, name_in_module)
+ return None
+
+
def _get_unbound_function(m):
# TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
# The failure case is for tf.keras.Model.
diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index f3eb027822..11074debfc 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from functools import wraps
+import imp
import six
@@ -127,6 +128,24 @@ class InspectUtilsTest(test.TestCase):
self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
self.assertTrue('local_var' not in ns)
+ def test_getqualifiedname(self):
+ foo = object()
+ qux = imp.new_module('quxmodule')
+ bar = imp.new_module('barmodule')
+ baz = object()
+ bar.baz = baz
+
+ ns = {
+ 'foo': foo,
+ 'bar': bar,
+ 'qux': qux,
+ }
+
+ self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
+ self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
+
def test_getmethodclass(self):
self.assertEqual(
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 8f4e8e0b98..349c84e13c 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 5)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 8)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 4eef9580ad..a67f6ff031 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -453,6 +453,18 @@ cuda_py_test(
tags = ["no_windows_gpu"],
)
+py_test(
+ name = "random_dataset_test",
+ srcs = ["random_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/data/experimental/ops:random_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
py_library(
name = "reader_dataset_ops_test_base",
testonly = 1,
@@ -562,6 +574,7 @@ py_test(
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index afd0fc3abf..d444c4082e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -332,6 +332,37 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+ @parameterized.named_parameters(
+ ("Identity", None, lambda x: x, None),
+ ("Replicate", None, lambda x: (x, x), None),
+ ("Swap", (None, None), lambda x, y: (y, x), None),
+ ("Project", (None, None), lambda x, y: x, None),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().apply(
+ batching.map_and_batch(map_fn, batch_size=10))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(
+ *sess.run(self.structuredElement(structure, shape=[10])))
+ else:
+ expected = map_fn(
+ sess.run(self.structuredElement(structure, shape=[10])))
+ self.assertAllEqual(expected, sess.run(get_next))
+
+ def testShortCircuitCapturedInput(self):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().apply(
+ batching.map_and_batch(lambda x: captured_t, batch_size=10))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertAllEqual([42] * 10, sess.run(get_next))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/random_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/random_dataset_test.py
new file mode 100644
index 0000000000..d403a575ec
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/random_dataset_test.py
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.RandomDataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+
+
+class RandomDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("NoSeed", None),
+ ("WithSeed", 42),
+ )
+ def testZipRandomDataset(self, seed):
+ dataset = random_ops.RandomDataset(seed=seed).take(30)
+ dataset = dataset_ops.Dataset.zip((dataset, dataset))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(30):
+ x, y = sess.run(next_element)
+ self.assertEqual(x, y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
index fe0b3b5f3b..77df8310d4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
@@ -64,7 +64,7 @@ class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing `make_batched_feature_dataset`."""
+ """Base class for setting up and testing `make_batched_features_dataset`."""
def setUp(self):
super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
index c208963a86..883169495f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import shuffle_ops
@@ -27,7 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class ShuffleAndRepeatTest(test_base.DatasetTestBase):
+class ShuffleAndRepeatTest(test_base.DatasetTestBase, parameterized.TestCase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
@@ -110,6 +111,24 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
with self.session(graph=g) as sess:
sess.run(get_next_op)
+ @parameterized.named_parameters(
+ ("NoSeed", None),
+ ("WithSeed", 42),
+ )
+ def testShuffleAndRepeatAndZipDataset(self, seed):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ shuffle_ops.shuffle_and_repeat(10, count=3, seed=seed))
+ dataset = dataset_ops.Dataset.zip((dataset, dataset))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(30):
+ x, y = sess.run(next_element)
+ self.assertEqual(x, y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py
index e3a2aeab31..25d7fbf691 100644
--- a/tensorflow/python/data/experimental/ops/random_ops.py
+++ b/tensorflow/python/data/experimental/ops/random_ops.py
@@ -33,13 +33,26 @@ class RandomDataset(dataset_ops.DatasetSource):
def __init__(self, seed=None):
"""A `Dataset` of pseudorandom values."""
super(RandomDataset, self).__init__()
- self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ # NOTE(mrry): We generate the seed-pair once per graph in which the dataset
+ # is iterated over, and cache it in `self._graph_seed_map`. This supports
+ # two features: iterating over the same `ShuffleDataset` twice in the same
+ # pipeline and observing the same order (by tying the seeds together with
+ # a randomly-generated seed), and using `Dataset.make_one_shot_iterator()`,
+ # which requires the stateful RNG op to be created inside the same graph as
+ # the dataset.
+ self._original_seed = seed
+ self._graph_seed_map = {}
def _as_variant_tensor(self):
+ try:
+ seed, seed2 = self._graph_seed_map[ops.get_default_graph()]
+ except KeyError:
+ seed, seed2 = random_seed.get_seed(self._original_seed)
+ self._graph_seed_map[ops.get_default_graph()] = (seed, seed2)
+
return gen_dataset_ops.random_dataset(
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
+ seed=seed, seed2=seed2, **dataset_ops.flat_structure(self))
@property
def output_classes(self):
diff --git a/tensorflow/python/data/experimental/ops/shuffle_ops.py b/tensorflow/python/data/experimental/ops/shuffle_ops.py
index a4307212da..a82e4b7d09 100644
--- a/tensorflow/python/data/experimental/ops/shuffle_ops.py
+++ b/tensorflow/python/data/experimental/ops/shuffle_ops.py
@@ -39,17 +39,32 @@ class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
else:
self._count = ops.convert_to_tensor(
count, dtype=dtypes.int64, name="count")
- self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ # NOTE(mrry): We generate the seed-pair once per graph in which the dataset
+ # is iterated over, and cache it in `self._graph_seed_map`. This supports
+ # two features: iterating over the same `ShuffleDataset` twice in the same
+ # pipeline and observing the same order (by tying the seeds together with
+ # a randomly-generated seed), and using `Dataset.make_one_shot_iterator()`,
+ # which requires the stateful RNG op to be created inside the same graph as
+ # the dataset.
+ self._original_seed = seed
+ self._graph_seed_map = {}
def _as_variant_tensor(self):
+ try:
+ seed, seed2 = self._graph_seed_map[ops.get_default_graph()]
+ except KeyError:
+ seed, seed2 = random_seed.get_seed(self._original_seed)
+ self._graph_seed_map[ops.get_default_graph()] = (seed, seed2)
+
# pylint: disable=protected-access
input_resource = self._input_dataset._as_variant_tensor()
return gen_dataset_ops.shuffle_and_repeat_dataset(
input_resource,
buffer_size=self._buffer_size,
count=self._count,
- seed=self._seed,
- seed2=self._seed2,
+ seed=seed,
+ seed2=seed2,
**dataset_ops.flat_structure(self))
# pylint: enable=protected-access
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index c7295d6e69..ecb24103b3 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -443,6 +443,7 @@ tf_py_test(
srcs = ["shuffle_dataset_op_test.py"],
additional_deps = [
":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 6b7afafa5d..a0c6b37a6d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -156,7 +156,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testReturnComponent(self):
+ def testShortCircuit(self):
iterator = (
dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10),
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 0c372ebb10..4683b1db91 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -622,7 +622,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -649,7 +649,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -783,19 +783,72 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
+ @parameterized.named_parameters(
+ ("SequentialIdentity", None, lambda x: x, None),
+ ("SequentialReplicate", None, lambda x: (x, x), None),
+ ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
+ ("SequentialProject", (None, None), lambda x, y: x, None),
+ ("ParallelIdentity", None, lambda x: x, 10),
+ ("ParallelReplicate", None, lambda x: (x, x), 10),
+ ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
+ ("ParallelProject", (None, None), lambda x, y: x, 10),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().map(
+ map_fn, num_parallel_calls=num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(*sess.run(self.structuredElement(structure)))
+ else:
+ expected = map_fn(sess.run(self.structuredElement(structure)))
+ self.assertEqual(expected, sess.run(get_next))
+
+ @parameterized.named_parameters(
+ ("Sequential", None),
+ ("Parallel", 10),
+ )
+ def testShortCircuitCapturedInput(self, num_parallel_calls):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().map(
+ lambda x: captured_t, num_parallel_calls=num_parallel_calls)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertEqual(42, sess.run(get_next))
+
class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda x: x
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
for _ in range(chain_length):
dataset = dataset_ops.MapDataset(
dataset,
- lambda x: x,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -813,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset chain length%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", chain_length, median_wall_time))
+ (print_label, chain_length, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
name="benchmark_map_dataset_chain_latency_%d%s" %
- (chain_length, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ (chain_length, benchmark_label))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda *xs: xs
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(
tuple(0 for _ in range(fan_out))).repeat(None)
dataset = dataset_ops.MapDataset(
dataset,
- lambda *xs: xs,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -849,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset fan out%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", fan_out, median_wall_time))
+ (print_label, fan_out, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d%s" %
- (fan_out, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
+ benchmark_label))
if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 347af18576..6001721726 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import collections
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
@@ -31,7 +32,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ShuffleDatasetTest(test_base.DatasetTestBase):
+class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testShuffleDataset(self):
components = (
@@ -209,5 +210,27 @@ class ShuffleDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ @parameterized.named_parameters(
+ ("ReshuffleEachIterationNoSeed", None, True),
+ ("ReshuffleEachIterationWithSeed", 42, True),
+ ("NoReshuffleEachIterationNoSeed", None, False),
+ ("NoReshuffleEachIterationWithSeed", 42, False),
+ )
+ def testShuffleAndZipDataset(self, seed, reshuffle):
+ dataset = (dataset_ops.Dataset.range(10)
+ .shuffle(10, seed=seed, reshuffle_each_iteration=reshuffle)
+ .repeat(3))
+ dataset = dataset_ops.Dataset.zip((dataset, dataset))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(30):
+ x, y = sess.run(next_element)
+ self.assertEqual(x, y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b730e10949..b73a94e683 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -19,10 +19,13 @@ from __future__ import print_function
import re
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -107,3 +110,29 @@ class DatasetTestBase(test.TestCase):
with self.assertRaisesRegexp(exception_class,
re.escape(expected_message)):
self.evaluate(next2())
+
+ def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns a singleton dataset with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self.structuredDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns an element with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return array_ops.zeros(shape, dtype=dtype)
+ else:
+ return tuple([
+ self.structuredElement(substructure, shape, dtype)
+ for substructure in structure
+ ])
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index b7e19055f2..2d036fd0d6 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -2254,18 +2254,34 @@ class ShuffleDataset(UnaryDataset):
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
- self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ # NOTE(mrry): We generate the seed-pair once per graph in which the dataset
+ # is iterated over, and cache it in `self._graph_seed_map`. This supports
+ # two features: iterating over the same `ShuffleDataset` twice in the same
+ # pipeline and observing the same order (by tying the seeds together with
+ # a randomly-generated seed), and using `Dataset.make_one_shot_iterator()`,
+ # which requires the stateful RNG op to be created inside the same graph as
+ # the dataset.
+ self._original_seed = seed
+ self._graph_seed_map = {}
+
if reshuffle_each_iteration is None:
self._reshuffle_each_iteration = True
else:
self._reshuffle_each_iteration = reshuffle_each_iteration
def _as_variant_tensor(self):
+ try:
+ seed, seed2 = self._graph_seed_map[ops.get_default_graph()]
+ except KeyError:
+ seed, seed2 = random_seed.get_seed(self._original_seed)
+ self._graph_seed_map[ops.get_default_graph()] = (seed, seed2)
+
return gen_dataset_ops.shuffle_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
buffer_size=self._buffer_size,
- seed=self._seed,
- seed2=self._seed2,
+ seed=seed,
+ seed2=seed2,
reshuffle_each_iteration=self._reshuffle_each_iteration,
**flat_structure(self))
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index 39082ce370..95bf3209d7 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -142,6 +142,7 @@ py_test(
":random_seed",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:random_ops",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/data/util/random_seed.py b/tensorflow/python/data/util/random_seed.py
index d5169f7a53..d24df6d957 100644
--- a/tensorflow/python/data/util/random_seed.py
+++ b/tensorflow/python/data/util/random_seed.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
def get_seed(seed):
@@ -37,7 +38,7 @@ def get_seed(seed):
Returns:
A tuple of two `tf.int64` scalar tensors that should be used for the local
- seed of the calling dataset.
+ seeds of the calling dataset.
"""
seed, seed2 = random_seed.get_seed(seed)
if seed is None:
@@ -45,7 +46,7 @@ def get_seed(seed):
else:
seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed")
if seed2 is None:
- seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2")
+ seed2 = random_ops.random_uniform([], 1, 2**63 - 1, dtype=dtypes.int64)
else:
with ops.name_scope("seed2") as scope:
seed2 = ops.convert_to_tensor(seed2, dtype=dtypes.int64)
diff --git a/tensorflow/python/data/util/random_seed_test.py b/tensorflow/python/data/util/random_seed_test.py
index a809151e6e..5df2e38c62 100644
--- a/tensorflow/python/data/util/random_seed_test.py
+++ b/tensorflow/python/data/util/random_seed_test.py
@@ -41,7 +41,6 @@ class RandomSeedTest(test.TestCase):
# (input_graph_seed, input_op_seed)
# and output from get_seed:
# (output_graph_seed, output_op_seed)
- ((None, None), (0, 0)),
((None, 1), (random_seed.DEFAULT_GRAPH_SEED, 1)),
((1, 1), (1, 1)),
((0, 0), (0, 2**31 - 1)), # Avoid nondeterministic (0, 0) output
@@ -78,6 +77,18 @@ class RandomSeedTest(test.TestCase):
self.assertEqual((g_seed, op_seed), toutput, msg=msg)
random_seed.set_random_seed(None)
+ @test_util.run_in_graph_and_eager_modes
+ def testNondeterministicRandomSeed(self):
+ random_seed.set_random_seed(None)
+ op_seeds = []
+ for _ in range(50):
+ g_seed, op_seed = data_random_seed.get_seed(None)
+ g_seed = self.evaluate(g_seed)
+ op_seed = self.evaluate(op_seed)
+ self.assertEqual(0, g_seed)
+ self.assertNotEqual(0, op_seed)
+ op_seeds.append(op_seed)
+ self.assertGreater(len(set(op_seeds)), 1)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d0c1a93118..cae809a7c3 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -251,6 +251,7 @@ py_library(
"//tensorflow/python:gradients_impl",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",
"//tensorflow/python/eager:execute",
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 2750461fb2..99bf375ea7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -31,6 +31,7 @@ import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
+from tensorflow.python import autograph
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
@@ -45,6 +46,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
@@ -80,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- _copy_handle_data(value, placeholder)
+ custom_gradient.copy_handle_data(value, placeholder)
return placeholder
-def _copy_handle_data(source_t, target_t):
- """Copies HandleData for variant and resource type tensors if available.
-
- The CppShapeInferenceResult::HandleData proto contains information about the
- shapes and types of the element tensors of resource/variant type tensors.
- We need to copy this across function boundaries, i.e., when capturing a
- placeholder or when returning a function tensor as output. If we don't do this
- the element tensors will have unknown shapes, e.g., if a TensorList variant
- tensor is captured as a placeholder, elements popped from that list would have
- unknown shape.
-
- Args:
- source_t: The tensor to copy HandleData from.
- target_t: The tensor to copy HandleData to.
- """
- if (target_t.dtype == dtypes_module.resource or
- target_t.dtype == dtypes_module.variant):
- if isinstance(source_t, ops.EagerTensor):
- handle_data = source_t._handle_data # pylint: disable=protected-access
- else:
- handle_data = resource_variable_ops.get_resource_handle_data(source_t)
- if handle_data is not None and handle_data.is_set:
- # pylint: disable=protected-access
- pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
- target_t._as_tf_output(),
- handle_data.SerializeToString())
- # pylint: enable=protected-access
- # Ensure that shapes and dtypes are propagated.
- shapes, types = zip(*[(pair.shape, pair.dtype)
- for pair in handle_data.shape_and_type])
- ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
- shapes = [[d.size for d in s.dim]
- if not s.unknown_rank else None for s in shapes]
- pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- target_t._op._graph._c_graph, # pylint: disable=protected-access
- target_t._as_tf_output(), # pylint: disable=protected-access
- shapes, ranks, types)
-
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
if ctx.executing_eagerly():
@@ -546,7 +509,7 @@ class _EagerDefinedFunction(object):
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
for i, func_graph_output in enumerate(self._func_graph_outputs):
- _copy_handle_data(func_graph_output, outputs[i])
+ custom_gradient.copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -657,7 +620,48 @@ class Function(object):
if tape.should_record(tensor_inputs) or tape.should_record(captures):
return self._backprop_call(args)
- outputs = self._inference_function.call(ctx, args)
+ # Only need to override the gradient in graph mode and when we have outputs.
+ if context.executing_eagerly() or not self.outputs:
+ outputs = self._inference_function.call(ctx, args)
+ else:
+ name = "PartitionedCall-%s" % ops.uid()
+
+ @ops.RegisterGradient(name)
+ def grad_fn(op, *doutputs): # pylint: disable=unused-variable
+ """Gradients of this function."""
+ if op.graph is not ops.get_default_graph():
+ # TODO(apassos) this will still emit SymbolicGradient ops when
+ # nested defuns are being differentiated. We need to somehow figure
+ # out a way to update the FunctionDef corresponding to the calling
+ # function when mutating a call to the forward pass.
+ return gradients_impl._SymGrad(op, list(doutputs)) # pylint: disable=protected-access
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
+ self._forward_function.add_to_graph(op.graph)
+ func = attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList(
+ name=self._forward_function.name))
+ # pylint: disable=protected-access
+ op._set_attr("f", func)
+ types = attr_value_pb2.AttrValue.ListValue(
+ type=self._forward_function._output_types)
+ op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+ for i in range(
+ len(outputs), len(self._forward_function._output_types)):
+ t = ops.Tensor(op, i, self._forward_function._output_types[i])
+ t.set_shape(self._forward_function._output_shapes[i])
+ func_graph_output = self._forward_function._func_graph_outputs[i]
+ custom_gradient.copy_handle_data(func_graph_output, t)
+ op._outputs.append(t)
+ # pylint: enable=protected-access
+ side_outputs = op.outputs[len(outputs):]
+ return self._backward_graph_function(
+ *(list(doutputs) + list(side_outputs)))
+
+ with ops.get_default_graph().gradient_override_map(
+ {"PartitionedCall": name}):
+ outputs = self._inference_function.call(ctx, args)
+
return self._build_call_outputs(outputs)
@property
@@ -854,20 +858,12 @@ class Function(object):
return ret
-def _get_defun_inputs_from_signature(signature):
- """Maps a signature to graph-construction inputs."""
- function_inputs = [
- graph_placeholder(spec.dtype, spec.shape)
- for spec in nest.flatten(signature)
- ]
- return nest.pack_sequence_as(signature, function_inputs)
-
-
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
graph_placeholder(arg.dtype, arg.shape)
- if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
+ if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
+ else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
@@ -877,7 +873,8 @@ def func_graph_from_py_func(name,
args,
kwargs,
signature=None,
- func_graph=None):
+ func_graph=None,
+ experimental_autograph=False):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -894,6 +891,8 @@ def func_graph_from_py_func(name,
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
this graph else a new one is built and returned.
+ experimental_autograph: whether to use autograph to compile `python_func`.
+ See https://www.tensorflow.org/guide/autograph for more information.
Returns:
A FuncGraph.
@@ -908,12 +907,12 @@ def func_graph_from_py_func(name,
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
- if signature is None:
- func_args = _get_defun_inputs_from_args(args)
- func_kwargs = _get_defun_inputs_from_args(kwargs)
- else:
- func_args = _get_defun_inputs_from_signature(signature)
- func_kwargs = {}
+ if signature is not None:
+ args = signature
+ kwargs = {}
+
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
@@ -939,7 +938,17 @@ def func_graph_from_py_func(name,
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwargs)
+ if experimental_autograph:
+ func_outputs = autograph.converted_call(
+ python_func,
+ autograph.ConversionOptions(
+ verbose=True,
+ recursive=True,
+ force_conversion=False,
+ strip_decorators=(defun,),
+ arg_types={}), *func_args, **func_kwargs)
+ else:
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -1035,7 +1044,8 @@ class PolymorphicFunction(object):
python_function,
name,
input_signature=None,
- attributes=None):
+ attributes=None,
+ experimental_autograph=False):
"""Initializes a polymorphic function.
Args:
@@ -1045,7 +1055,10 @@ class PolymorphicFunction(object):
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
attributes: dict, extra keyword arguments that will be added as attribute
- of the function.
+ of the function.
+ experimental_autograph: whether to use autograph to compile
+ `python_function`. See https://www.tensorflow.org/guide/autograph for
+ more information.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -1061,6 +1074,7 @@ class PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwargs_to_include = {}
self._name = name
+ self._experimental_autograph = experimental_autograph
self._function_cache = collections.OrderedDict()
self._function_attributes = attributes or {}
@@ -1286,8 +1300,13 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
- func_graph_from_py_func(self._name, self._python_function, args,
- kwargs, self._input_signature),
+ func_graph_from_py_func(
+ self._name,
+ self._python_function,
+ args,
+ kwargs,
+ self._input_signature,
+ experimental_autograph=self._experimental_autograph),
self._function_attributes)
self._function_cache[cache_key] = graph_function
return graph_function, [
@@ -1348,7 +1367,7 @@ def _validate_signature(signature):
"a possibly nested sequence of TensorSpec objects.")
-def defun(func=None, input_signature=None):
+def defun(func=None, input_signature=None, experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1657,6 +1676,10 @@ def defun(func=None, input_signature=None):
function is instantiated for each inferred input signature. If a
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
+ experimental_autograph: Whether `func` should be compiled before
+ constructing the graph. See https://www.tensorflow.org/guide/autograph
+ for more information.
+
Returns:
If `func` is not None, returns a callable that will execute the compiled
@@ -1668,10 +1691,16 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
- return defun_with_attributes(func=func, input_signature=input_signature)
+ return defun_with_attributes(
+ func=func,
+ input_signature=input_signature,
+ experimental_autograph=experimental_autograph)
-def defun_with_attributes(func=None, input_signature=None, attributes=None):
+def defun_with_attributes(func=None,
+ input_signature=None,
+ attributes=None,
+ experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
This function supports adding extra function attributes. See detailed
@@ -1686,6 +1715,7 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
attributes. Currently only support primitive types as value, and only
whitelisted attribute name is allowed. Unwhitelisted attribute name or
unsupported value will result into ValueError.
+ experimental_autograph: same as defun()'s experimental_autograph.
Returns:
Same as the return value of defun, with attributes added to the function in
@@ -1702,8 +1732,12 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature,
- attributes=attributes))
+ PolymorphicFunction(
+ function,
+ name,
+ input_signature=input_signature,
+ attributes=attributes,
+ experimental_autograph=experimental_autograph))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1906,8 +1940,10 @@ class AutomaticControlDependencies(object):
last_op_using_resource_tensor[inp] = op
ops_which_must_run = set([op])
continue
+ found_resource = False
for inp in op.inputs:
if inp.dtype == dtypes_module.resource:
+ found_resource = True
# Deal with switches, finally.
if inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
@@ -1922,6 +1958,11 @@ class AutomaticControlDependencies(object):
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op
+ if (op.op_def.is_stateful and not found_resource
+ and op._control_flow_context is None): # pylint: disable=protected-access
+ if None in last_op_using_resource_tensor:
+ op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
+ last_op_using_resource_tensor[None] = op
control_inputs = [c for c in control_inputs
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index a2cfb4b476..e46bde098b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -172,6 +172,43 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
+ def testInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(a):
+ return matmul(a, a)
+
+ sq_op = sq.get_concrete_function(
+ tensor_spec.TensorSpec((None, None), dtypes.float32))
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out1 = sq_op(t1)
+ self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
+
+ t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out2 = sq_op(t2)
+ self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
+
+ def testNestedInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(mats):
+ ((a, b),) = mats
+ return matmul(a, b)
+
+ sq_op = sq.get_concrete_function(
+ [(tensor_spec.TensorSpec((None, None), dtypes.float32),
+ tensor_spec.TensorSpec((None, None), dtypes.float32))])
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
+ out = sq_op(t1, t2) # Flattened structure for inputs to the graph function
+ self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
+
def testExecutingStatelessDefunConcurrently(self):
@function.defun
@@ -249,7 +286,23 @@ class FunctionTest(test.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
- self.assertAllEqual(sess.run(g), [[1.0]])
+ self.assertAllEqual(sess.run(g).values, [[1.0]])
+
+ def testNoSymGradNestedDefun(self):
+
+ @function.defun
+ def outer():
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertTrue(isinstance(g, ops.IndexedSlices))
+
+ outer()
def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f5af4ab6c..5c35860e9d 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -51,11 +51,6 @@ def imperative_grad(
Raises:
RuntimeError: if something goes wrong.
- ValueError: if there is no sequence of differentiable operations connecting
- a source and any target Tensor. This can happen either if the target is
- not computed based on the source, if the tracing was set up incorrectly,
- or if only non-differentiable functions of the source were used in the
- computation of target.
"""
return pywrap_tensorflow.TFE_Py_TapeGradient(
tape._tape, # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6d3ef9a37b..9789dbadee 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1836,6 +1836,8 @@ bool OpGradientDoesntRequireOutputIndices(
{"SoftplusGrad", {true, {}}},
{"Softsign", {true, {}}},
{"ReluGrad", {true, {}}},
+ {"LeakyRelu", {true, {}}},
+ {"LeakyReluGrad", {true, {}}},
{"Conv2D", {true, {}}},
{"DepthwiseConv2dNative", {true, {}}},
{"Dilation2D", {true, {}}},
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 5352796174..28a8286544 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2660,6 +2660,7 @@ class _EmbeddingColumn(
inputs=inputs,
weight_collections=weight_collections,
trainable=trainable)
+
sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
sequence_length = _sequence_length_from_sparse_tensor(
sparse_tensors.id_tensor)
@@ -3383,6 +3384,16 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
def _verify_static_batch_size_equality(tensors, columns):
+ """Validates that the first dim (batch size) of all tensors are equal or None.
+
+ Args:
+ tensors: list of tensors to check.
+ columns: list of feature columns matching tensors. Will be used for error
+ messaging.
+
+ Raises:
+ ValueError: if one of the tensors has a variant batch size
+ """
# bath_size is a tf.Dimension object.
expected_batch_size = None
for i in range(0, len(tensors)):
@@ -3403,9 +3414,18 @@ def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
with ops.name_scope(None, 'sequence_length') as name_scope:
row_ids = sp_tensor.indices[:, 0]
column_ids = sp_tensor.indices[:, 1]
+ # Add one to convert column indices to element length
column_ids += array_ops.ones_like(column_ids)
- seq_length = math_ops.to_int64(
- math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
+ # Get the number of elements we will have per example/row
+ seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids)
+
+ # The raw values are grouped according to num_elements;
+ # how many entities will we have after grouping?
+ # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1),
+ # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2,
+ # these will get grouped, and the final seq_length is [1, 1]
+ seq_length = math_ops.to_int64(math_ops.ceil(seq_length / num_elements))
+
# If the last n rows do not have ids, seq_length will have shape
# [batch_size - n]. Pad the remaining values with zeros.
n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
@@ -3439,25 +3459,14 @@ class _SequenceCategoricalColumn(
sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
id_tensor = sparse_tensors.id_tensor
weight_tensor = sparse_tensors.weight_tensor
- # Expands final dimension, so that embeddings are not combined during
- # embedding lookup.
- check_id_rank = check_ops.assert_equal(
- array_ops.rank(id_tensor), 2,
- data=[
- 'Column {} expected ID tensor of rank 2. '.format(self.name),
- 'id_tensor shape: ', array_ops.shape(id_tensor)])
- with ops.control_dependencies([check_id_rank]):
- id_tensor = sparse_ops.sparse_reshape(
- id_tensor,
- shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+
+ # Expands third dimension, if necessary so that embeddings are not
+ # combined during embedding lookup. If the tensor is already 3D, leave
+ # as-is.
+ shape = array_ops.shape(id_tensor)
+ target_shape = [shape[0], shape[1], -1]
+ id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
if weight_tensor is not None:
- check_weight_rank = check_ops.assert_equal(
- array_ops.rank(weight_tensor), 2,
- data=[
- 'Column {} expected weight tensor of rank 2.'.format(self.name),
- 'weight_tensor shape:', array_ops.shape(weight_tensor)])
- with ops.control_dependencies([check_weight_rank]):
- weight_tensor = sparse_ops.sparse_reshape(
- weight_tensor,
- shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+ weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
+
return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index e85bba11cd..9955a9a2cd 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -482,7 +482,8 @@ class OpDefLibrary(object):
else:
raise TypeError("%s that don't all match." % prefix)
else:
- raise TypeError("%s that are invalid." % prefix)
+ raise TypeError(
+ "%s that are invalid. Tensors: %s" % (prefix, values))
types = [x.dtype for x in values]
inputs.extend(values)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8bb177939e..77c2bc930e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4140,10 +4140,7 @@ class Graph(object):
if op is None and not ignore_existing:
raise ValueError("Trying to reset colocation (op is None) but "
"ignore_existing is not True")
-
- if op is not None and not isinstance(op, Operation):
- # We always want to colocate with the reference op.
- op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
+ op = _op_to_colocate_with(op)
# By default, colocate_with resets the device function stack,
# since colocate_with is typically used in specific internal
@@ -6168,4 +6165,27 @@ def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
name, as_ref))
+def _op_to_colocate_with(v):
+ """Operation object corresponding to v to use for colocation constraints."""
+ if v is None:
+ return None
+ if isinstance(v, Operation):
+ return v
+ # We always want to colocate with the reference op.
+ # When 'v' is a ResourceVariable, the reference op is the handle creating op.
+ #
+ # What this should be is:
+ # if isinstance(v, ResourceVariable):
+ # return v.handle.op
+ # However, that would require a circular import dependency.
+ # As of October 2018, there were attempts underway to remove
+ # colocation constraints altogether. Assuming that will
+ # happen soon, perhaps this hack to work around the circular
+ # import dependency is acceptable.
+ if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance(
+ v.handle.op, Operation):
+ return v.handle.op
+ return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op
+
+
register_tensor_conversion_function(Operation, _operation_conversion_error)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4ec4b41b5e..95925bb471 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -506,9 +506,9 @@ def disable_control_flow_v2(unused_msg):
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
- Runs the test multiple times executing eagerly, first as a warmup and then
- several times to let objects accumulate. The warmup helps ignore caches which
- do not grow as the test is run repeatedly.
+ Runs the test multiple times executing eagerly, first as a warmup and then to
+ let objects accumulate. The warmup helps ignore caches which do not grow as
+ the test is run repeatedly.
Useful for checking that there are no missing Py_DECREFs in the C exercised by
a bit of Python.
@@ -518,7 +518,14 @@ def assert_no_new_pyobjects_executing_eagerly(f):
"""Warms up, gets an object count, runs the test, checks for new objects."""
with context.eager_mode():
gc.disable()
- f(self, **kwargs)
+ # Run the test 2 times as warmup, in an attempt to fill up caches, which
+ # should not grow as the test is run repeatedly below.
+ #
+ # TODO(b/117156879): Running warmup twice is black magic; we have seen
+ # tests that fail with 1 warmup run, and pass with 2, on various versions
+ # of python2.7.x.
+ for _ in range(2):
+ f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
if ops.has_default_graph():
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 4a72c4b3f3..c4d23f117f 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -62,6 +62,7 @@ py_library(
":backend",
":engine",
":layers",
+ ":optimizer_v2",
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
@@ -189,6 +190,30 @@ py_library(
],
)
+py_library(
+ name = "optimizer_v2",
+ srcs = [
+ "optimizer_v2/adadelta.py",
+ "optimizer_v2/adagrad.py",
+ "optimizer_v2/adam.py",
+ "optimizer_v2/optimizer_v2.py",
+ "optimizer_v2/rmsprop.py",
+ "optimizer_v2/sgd.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distribute",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
py_test(
name = "integration_test",
size = "medium",
@@ -827,3 +852,133 @@ py_library(
"//third_party/py/numpy",
],
)
+
+cuda_py_test(
+ name = "adadelta_test",
+ size = "medium",
+ srcs = ["optimizer_v2/adadelta_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adagrad_test",
+ size = "small",
+ srcs = ["optimizer_v2/adagrad_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adam_test",
+ size = "small",
+ srcs = ["optimizer_v2/adam_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "checkpointable_utils_test",
+ srcs = ["optimizer_v2/checkpointable_utils_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@six_archive//:six",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:layers_base",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/keras",
+ ],
+ tags = ["notsan"],
+)
+
+cuda_py_test(
+ name = "sgd_test",
+ size = "medium",
+ srcs = ["optimizer_v2/sgd_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:resources",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "optimizer_v2_test",
+ size = "medium",
+ srcs = ["optimizer_v2/optimizer_v2_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:clip_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
+ name = "rmsprop_test",
+ size = "small",
+ srcs = ["optimizer_v2/rmsprop_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+ tags = ["optonly"],
+)
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index 99645de736..d69791ce8d 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -160,6 +160,11 @@ def sigmoid(x):
return nn.sigmoid(x)
+@tf_export('keras.activations.exponential')
+def exponential(x):
+ return math_ops.exp(x)
+
+
@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
"""Hard sigmoid activation function.
diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py
index dd0bbcff39..ad238cb0a9 100644
--- a/tensorflow/python/keras/activations_test.py
+++ b/tensorflow/python/keras/activations_test.py
@@ -169,6 +169,16 @@ class KerasActivationsTest(test.TestCase):
expected = np.tanh(test_values)
self.assertAllClose(result, expected, rtol=1e-05)
+ def test_exponential(self):
+ with self.cached_session():
+ test_values = np.random.random((2, 5))
+ x = keras.backend.placeholder(ndim=2)
+ exp = keras.activations.exponential(x)
+ f = keras.backend.function([x], [exp])
+ result = f([test_values])[0]
+ expected = np.exp(test_values)
+ self.assertAllClose(result, expected, rtol=1e-05)
+
def test_linear(self):
x = np.random.random((10, 5))
self.assertAllClose(x, keras.activations.linear(x))
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 8ebf7356cd..13f52fbae7 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -774,6 +774,8 @@ def is_keras_tensor(x):
Examples:
```python
+ >>> import tensorflow as tf
+ >>> import numpy
>>> from keras import backend as K
>>> from keras.layers import Input, Dense
>>> np_var = numpy.array([1, 2])
@@ -2221,7 +2223,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
@tf_export('keras.backend.batch_normalization')
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
"""Applies batch normalization on x given mean, var, beta and gamma.
I.e. returns:
@@ -2233,11 +2235,49 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
var: Variance of batch.
beta: Tensor with which to center the input.
gamma: Tensor by which to scale the input.
+ axis: Integer, the axis that should be normalized.
+ (typically the features axis).
epsilon: Fuzz factor.
Returns:
A tensor.
"""
+ if ndim(x) == 4:
+ # The CPU implementation of `fused_batch_norm` only supports NHWC
+ if axis == 1 or axis == -3:
+ tf_data_format = 'NCHW'
+ elif axis == 3 or axis == -1:
+ tf_data_format = 'NHWC'
+ else:
+ tf_data_format = None
+
+ if (tf_data_format == 'NHWC' or
+ tf_data_format == 'NCHW' and _has_nchw_support()):
+ # The mean / var / beta / gamma tensors may be broadcasted
+ # so they may have extra axes of size 1, which should be squeezed.
+ if ndim(mean) > 1:
+ mean = array_ops.reshape(mean, [-1])
+ if ndim(var) > 1:
+ var = array_ops.reshape(var, [-1])
+ if beta is None:
+ beta = zeros_like(mean)
+ elif ndim(beta) > 1:
+ beta = array_ops.reshape(beta, [-1])
+ if gamma is None:
+ gamma = ones_like(mean)
+ elif ndim(gamma) > 1:
+ gamma = array_ops.reshape(gamma, [-1])
+ y, _, _ = nn.fused_batch_norm(
+ x,
+ gamma,
+ beta,
+ epsilon=epsilon,
+ mean=mean,
+ variance=var,
+ data_format=tf_data_format,
+ is_training=False
+ )
+ return y
return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
@@ -2878,7 +2918,7 @@ class Function(object):
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
- 'time: %s', session_kwargs.keys())
+ 'time: %s', (session_kwargs.keys(),))
self._callable_fn = None
self._feed_arrays = None
@@ -3796,19 +3836,23 @@ def _preprocess_conv1d_input(x, data_format):
return x, tf_data_format
-def _preprocess_conv2d_input(x, data_format):
+def _preprocess_conv2d_input(x, data_format, force_transpose=False):
"""Transpose and cast the input before the conv2d.
Arguments:
x: input tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
+ force_transpose: Boolean. If True, the input will always be transposed
+ from NCHW to NHWC if `data_format` is `"channels_first"`.
+ If False, the transposition only occurs on CPU (GPU ops are
+ assumed to support NCHW).
Returns:
A tensor.
"""
tf_data_format = 'NHWC'
if data_format == 'channels_first':
- if not _has_nchw_support():
+ if not _has_nchw_support() or force_transpose:
x = array_ops.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
else:
tf_data_format = 'NCHW'
@@ -3956,7 +4000,8 @@ def conv2d_transpose(x,
output_shape,
strides=(1, 1),
padding='valid',
- data_format=None):
+ data_format=None,
+ dilation_rate=(1, 1)):
"""2D deconvolution (i.e.
transposed convolution).
@@ -3970,6 +4015,7 @@ def conv2d_transpose(x,
data_format: string, `"channels_last"` or `"channels_first"`.
Whether to use Theano or TensorFlow/CNTK data format
for inputs/kernels/outputs.
+ dilation_rate: Tuple of 2 integers.
Returns:
A tensor, result of transposed 2D convolution.
@@ -3985,7 +4031,13 @@ def conv2d_transpose(x,
if isinstance(output_shape, (tuple, list)):
output_shape = array_ops.stack(output_shape)
- x, tf_data_format = _preprocess_conv2d_input(x, data_format)
+ # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
+ if data_format == 'channels_first' and dilation_rate != (1, 1):
+ force_transpose = True
+ else:
+ force_transpose = False
+
+ x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
output_shape = (output_shape[0], output_shape[2], output_shape[3],
@@ -4000,13 +4052,18 @@ def conv2d_transpose(x,
else:
strides = (1, 1) + strides
- x = nn.conv2d_transpose(
- x,
- kernel,
- output_shape,
- strides,
- padding=padding,
- data_format=tf_data_format)
+ if dilation_rate == (1, 1):
+ x = nn.conv2d_transpose(x, kernel, output_shape, strides,
+ padding=padding,
+ data_format=tf_data_format)
+ else:
+ assert dilation_rate[0] == dilation_rate[1]
+ x = nn.atrous_conv2d_transpose(
+ x,
+ kernel,
+ output_shape,
+ rate=dilation_rate[0],
+ padding=padding)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
return x
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index ab71589940..0834448699 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -26,6 +26,7 @@ from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import nn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -1381,6 +1382,36 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(mean.get_shape().as_list(), [3,])
self.assertEqual(var.get_shape().as_list(), [3,])
+ def test_batch_normalization(self):
+ g_val = np.random.random((3,))
+ b_val = np.random.random((3,))
+ gamma = keras.backend.variable(g_val)
+ beta = keras.backend.variable(b_val)
+
+ # 3D NHC case
+ val = np.random.random((10, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 3])
+
+ # 4D NHWC case
+ val = np.random.random((10, 5, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1, 2), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3])
+
+ # 4D NCHW case
+ val = np.random.random((10, 3, 5, 5))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 2, 3), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5])
+
class TestCTC(test.TestCase):
@@ -1506,12 +1537,13 @@ class TestRandomOps(test.TestCase):
self.assertAllClose(np.min(y), -2., atol=0.1)
def test_string_input(self):
- seq = keras.Sequential([
- keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
- keras.layers.Lambda(lambda x: x[0])
- ])
- preds = seq.predict([['tensorflow eager']])
- self.assertEqual(preds.shape, (1,))
+ with self.cached_session():
+ seq = keras.Sequential([
+ keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
+ keras.layers.Lambda(lambda x: x[0])
+ ])
+ preds = seq.predict([['tensorflow eager']])
+ self.assertEqual(preds.shape, (1,))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 6dfbbf3694..3d6000f223 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -781,6 +781,10 @@ class LearningRateScheduler(Callback):
print('\nEpoch %05d: LearningRateScheduler reducing learning '
'rate to %s.' % (epoch + 1, lr))
+ def on_epoch_end(self, epoch, logs=None):
+ logs = logs or {}
+ logs['lr'] = K.get_value(self.model.optimizer.lr)
+
@tf_export('keras.callbacks.TensorBoard')
class TensorBoard(Callback):
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 918488bd7a..5969fea2b2 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1641,10 +1641,11 @@ class Network(base_layer.Layer):
ValueError: if `summary()` is called before the model is built.
"""
if not self.built:
- raise ValueError('This model has never been called, thus its weights '
- 'have not yet been created, so no summary can be '
- 'displayed. Build the model first '
- '(e.g. by calling it on some data).')
+ raise ValueError('This model has not yet been built. '
+ 'Build the model first by calling `build()` or calling '
+ '`fit()` with some data, or specify '
+ 'an `input_shape` argument in the first layer(s) for '
+ 'automatic build.')
layer_utils.print_summary(self,
line_length=line_length,
positions=positions,
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 2ebb4cf99f..ff2ae54ad4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -563,9 +563,11 @@ class Model(Network):
for name in self.output_names:
tmp_target_tensors.append(target_tensors.get(name, None))
target_tensors = tmp_target_tensors
+ elif tensor_util.is_tensor(target_tensors):
+ target_tensors = [target_tensors]
else:
- raise TypeError('Expected `target_tensors` to be '
- 'a list or dict, but got:', target_tensors)
+ raise TypeError('Expected `target_tensors` to be a list or tuple or '
+ 'dict or a single tensor, but got:', target_tensors)
for i in range(len(self.outputs)):
if i in skip_target_indices:
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 04e8d079c0..ac759ef3aa 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -820,10 +820,6 @@ def _clone_and_build_model(model, inputs=None, targets=None):
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
- # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
- # single tensor should be OK but it throws an error in that case.
- if targets is not None and not isinstance(targets, (list, dict, tuple)):
- targets = [targets]
if isinstance(targets, tuple):
targets = nest.flatten(targets)
cloned_model.compile(
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 54ad74c08b..868fd1dc69 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1865,6 +1865,10 @@ class TestTrainingWithDataTensors(test.TestCase):
model.compile(optimizer='rmsprop', loss='mse', target_tensors=[target])
model.train_on_batch(input_val, None)
+ # single-output, as single tensor
+ model.compile(optimizer='rmsprop', loss='mse', target_tensors=target)
+ model.train_on_batch(input_val, None)
+
# single-output, as dict
model.compile(optimizer='rmsprop', loss='mse',
target_tensors={'dense': target})
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index d00def07bb..8f5872385c 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -645,6 +645,14 @@ class Conv2DTranspose(Conv2D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 2 integers,
+ specifying the amount of padding along the height and width
+ of the output tensor.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -700,7 +708,9 @@ class Conv2DTranspose(Conv2D):
kernel_size,
strides=(1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
+ dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
@@ -717,6 +727,7 @@ class Conv2DTranspose(Conv2D):
strides=strides,
padding=padding,
data_format=data_format,
+ dilation_rate=dilation_rate,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=initializers.get(kernel_initializer),
@@ -728,6 +739,16 @@ class Conv2DTranspose(Conv2D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 2, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 4:
@@ -769,51 +790,50 @@ class Conv2DTranspose(Conv2D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, h_axis, w_axis = 1, 2, 3
+ h_axis, w_axis = 2, 3
else:
- c_axis, h_axis, w_axis = 3, 1, 2
+ h_axis, w_axis = 1, 2
height, width = inputs_shape[h_axis], inputs_shape[w_axis]
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
- strides = (1, 1, stride_h, stride_w)
else:
output_shape = (batch_size, out_height, out_width, self.filters)
- strides = (1, stride_h, stride_w, 1)
output_shape_tensor = array_ops.stack(output_shape)
- outputs = nn.conv2d_transpose(
+ outputs = backend.conv2d_transpose(
inputs,
self.kernel,
output_shape_tensor,
- strides,
- padding=self.padding.upper(),
- data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dilation_rate=self.dilation_rate)
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -837,13 +857,33 @@ class Conv2DTranspose(Conv2D):
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv2DTranspose, self).get_config()
+ config['output_padding'] = self.output_padding
+ return config
+
@tf_export('keras.layers.Conv3DTranspose',
'keras.layers.Convolution3DTranspose')
@@ -878,6 +918,14 @@ class Conv3DTranspose(Conv3D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 3 integers,
+ specifying the amount of padding along the depth, height, and
+ width.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -943,6 +991,7 @@ class Conv3DTranspose(Conv3D):
kernel_size,
strides=(1, 1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
activation=None,
use_bias=True,
@@ -971,6 +1020,16 @@ class Conv3DTranspose(Conv3D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 3, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 5:
@@ -1012,11 +1071,9 @@ class Conv3DTranspose(Conv3D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
+ d_axis, h_axis, w_axis = 2, 3, 4
else:
- c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
-
- self.input_spec = InputSpec(ndim=5, axes={c_axis: inputs_shape[c_axis]})
+ d_axis, h_axis, w_axis = 1, 2, 3
depth = inputs_shape[d_axis]
height = inputs_shape[h_axis]
@@ -1025,19 +1082,27 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_depth = conv_utils.deconv_output_length(depth,
kernel_d,
- self.padding,
- stride_d)
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_depth, out_height,
out_width)
@@ -1058,20 +1123,7 @@ class Conv3DTranspose(Conv3D):
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[d_axis] = conv_utils.deconv_output_length(out_shape[d_axis],
- kernel_d,
- self.padding,
- stride_d)
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -1109,15 +1161,38 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[d_axis] = conv_utils.deconv_output_length(
- output_shape[d_axis], kernel_d, self.padding, stride_d)
+ output_shape[d_axis],
+ kernel_d,
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv3DTranspose, self).get_config()
+ config.pop('dilation_rate')
+ config['output_padding'] = self.output_padding
+ return config
+
class SeparableConv(Conv):
"""Abstract base layer for separable nD convolution.
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index cad5e4c8bd..f88d632ab5 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -204,6 +204,9 @@ class Conv2DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1)])
+
def test_conv2dtranspose_regularizers(self):
kwargs = {
'filters': 3,
@@ -239,6 +242,31 @@ class Conv2DTransposeTest(test.TestCase):
self.assertEqual(layer.kernel.constraint, k_constraint)
self.assertEqual(layer.bias.constraint, b_constraint)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_conv2d_transpose_dilation(self):
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ kwargs={'filters': 2,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2)},
+ input_shape=(2, 5, 6, 3))
+
+ input_data = np.arange(48).reshape((1, 4, 4, 3)).astype(np.float32)
+ expected_output = np.float32([[192, 228, 192, 228],
+ [336, 372, 336, 372],
+ [192, 228, 192, 228],
+ [336, 372, 336, 372]]).reshape((1, 4, 4, 1))
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ input_data=input_data,
+ kwargs={'filters': 1,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2),
+ 'kernel_initializer': 'ones'},
+ expected_output=expected_output)
+
class Conv3DTransposeTest(test.TestCase):
@@ -270,6 +298,9 @@ class Conv3DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1, 1)])
+
def test_conv3dtranspose_regularizers(self):
kwargs = {
'filters': 3,
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 912e8bd619..72a9c1d629 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@@ -41,16 +44,18 @@ class Pooling1D(Layer):
strides of the pooling operation.
padding: A string. The padding method, either 'valid' or 'same'.
Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
name: A string, the name of the layer.
"""
def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format=None,
+ padding='valid', data_format='channels_last',
name=None, **kwargs):
super(Pooling1D, self).__init__(name=name, **kwargs)
if data_format is None:
@@ -65,45 +70,39 @@ class Pooling1D(Layer):
self.input_spec = InputSpec(ndim=3)
def call(self, inputs):
- # There is no TF op for 1D pooling, hence we make the inputs 4D.
- if self.data_format == 'channels_last':
- # input is NWC, make it NHWC
- inputs = array_ops.expand_dims(inputs, 1)
- # pool on the W dim
- pool_shape = (1, 1) + self.pool_size + (1,)
- strides = (1, 1) + self.strides + (1,)
- data_format = 'NHWC'
- else:
- # input is NCW, make it NCHW
- inputs = array_ops.expand_dims(inputs, 2)
- # pool on the W dim
- pool_shape = (1, 1, 1) + self.pool_size
- strides = (1, 1, 1) + self.strides
- data_format = 'NCHW'
-
+ pad_axis = 2 if self.data_format == 'channels_last' else 3
+ inputs = array_ops.expand_dims(inputs, pad_axis)
outputs = self.pool_function(
inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper(),
- data_format=data_format)
-
- if self.data_format == 'channels_last':
- return array_ops.squeeze(outputs, 1)
- else:
- return array_ops.squeeze(outputs, 2)
+ self.pool_size + (1,),
+ strides=self.strides + (1,),
+ padding=self.padding,
+ data_format=self.data_format)
+ return array_ops.squeeze(outputs, pad_axis)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0],
- self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
+ if self.data_format == 'channels_first':
+ steps = input_shape[2]
+ features = input_shape[1]
+ else:
+ steps = input_shape[1]
+ features = input_shape[2]
+ length = conv_utils.conv_output_length(steps,
+ self.pool_size[0],
+ self.padding,
+ self.strides[0])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], features, length])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], length, features])
def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
- 'padding': self.padding
+ 'padding': self.padding,
+ 'data_format': self.data_format,
}
base_config = super(Pooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -119,19 +118,36 @@ class MaxPooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(MaxPooling1D, self).__init__(
- nn.max_pool,
+ functools.partial(backend.pool2d, pool_mode='max'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -149,18 +165,35 @@ class AveragePooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(AveragePooling1D, self).__init__(
- nn.avg_pool,
+ functools.partial(backend.pool2d, pool_mode='avg'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -561,41 +594,96 @@ class GlobalPooling1D(Layer):
"""Abstract class for different global pooling 1D layers.
"""
- def __init__(self, **kwargs):
+ def __init__(self, data_format='channels_last', **kwargs):
super(GlobalPooling1D, self).__init__(**kwargs)
self.input_spec = InputSpec(ndim=3)
+ self.data_format = conv_utils.normalize_data_format(data_format)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], input_shape[1]])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
def call(self, inputs):
raise NotImplementedError
+ def get_config(self):
+ config = {'data_format': self.data_format}
+ base_config = super(GlobalPooling1D, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
@tf_export('keras.layers.GlobalAveragePooling1D',
'keras.layers.GlobalAvgPool1D')
class GlobalAveragePooling1D(GlobalPooling1D):
"""Global average pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
`(batch_size, features)`
"""
- def call(self, inputs):
- return backend.mean(inputs, axis=1)
+ def __init__(self, data_format='channels_last', **kwargs):
+ super(GlobalAveragePooling1D, self).__init__(data_format=data_format,
+ **kwargs)
+ self.supports_masking = True
+
+ def call(self, inputs, mask=None):
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ if mask is not None:
+ mask = math_ops.cast(mask, backend.floatx())
+ input_shape = inputs.shape.as_list()
+ broadcast_shape = [-1, input_shape[steps_axis], 1]
+ mask = array_ops.reshape(mask, broadcast_shape)
+ inputs *= mask
+ return backend.sum(inputs, axis=steps_axis) / math_ops.reduce_sum(
+ mask, axis=steps_axis)
+ else:
+ return backend.mean(inputs, axis=steps_axis)
+
+ def compute_mask(self, inputs, mask=None):
+ return None
@tf_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D')
class GlobalMaxPooling1D(GlobalPooling1D):
"""Global max pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
@@ -603,7 +691,8 @@ class GlobalMaxPooling1D(GlobalPooling1D):
"""
def call(self, inputs):
- return backend.max(inputs, axis=1)
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ return backend.max(inputs, axis=steps_axis)
class GlobalPooling2D(Layer):
diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py
index 2cd9939e66..936e73ecf9 100644
--- a/tensorflow/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/layers/pooling_test.py
@@ -18,11 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
class GlobalPoolingTest(test.TestCase):
@@ -31,8 +34,26 @@ class GlobalPoolingTest(test.TestCase):
def test_globalpooling_1d(self):
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
testing_utils.layer_test(
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalAveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_globalpooling_1d_masking_support(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Masking(mask_value=0., input_shape=(3, 4)))
+ model.add(keras.layers.GlobalAveragePooling1D())
+ model.compile(loss='mae', optimizer=rmsprop.RMSPropOptimizer(0.001))
+
+ model_input = np.random.random((2, 3, 4))
+ model_input[0, 1:, :] = 0
+ output = model.predict(model_input)
+ self.assertAllClose(output[0], model_input[0, 0, :])
@tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_2d(self):
@@ -172,6 +193,10 @@ class Pooling1DTest(test.TestCase):
kwargs={'strides': stride,
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.MaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
@tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_1d(self):
@@ -183,6 +208,11 @@ class Pooling1DTest(test.TestCase):
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.AveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index a1933c11b0..d19d0b5f8c 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -587,6 +587,9 @@ class Bidirectional(Wrapper):
output = y * y_rev
elif self.merge_mode is None:
output = [y, y_rev]
+ else:
+ raise ValueError(
+ 'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))
# Properly set learning phase
if (getattr(y, '_uses_learning_phase', False) or
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
new file mode 100644
index 0000000000..d3b3c9c12e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -0,0 +1,116 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adadelta for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.training import training_ops
+
+
+class Adadelta(optimizer_v2.OptimizerV2):
+ """Adadelta optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
+ ([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate. It is recommended
+ to leave it at the default value.
+ rho: float hyperparameter >= 0. The decay rate.
+ epsilon: float hyperparameter >= 0. Fuzz factor. A constant epsilon used
+ to better condition the grad update.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adadelta'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ rho=0.95,
+ epsilon=1e-8,
+ name="Adadelta"):
+ super(Adadelta, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("rho", rho)
+ self._set_hyper("epsilon", epsilon)
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ state.zeros_slot(v, "accum")
+ state.zeros_slot(v, "accum_update")
+
+ def _apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.sparse_apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_sparse_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
new file mode 100644
index 0000000000..6e48f92e4f
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -0,0 +1,166 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adadelta Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdadeltaOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ num_updates = 4 # number of ADADELTA steps to perform
+ for dtype in [dtypes.half, dtypes.float32]:
+ for grad in [0.2, 0.1, 0.01]:
+ for lr in [1.0, 0.5, 0.1]:
+ with self.cached_session():
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+ else:
+ var0 = variables.Variable(var0_init, dtype=dtype)
+ var1 = variables.Variable(var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ adadelta_opt = adadelta.Adadelta(lr, rho, epsilon)
+ adadelta_update = adadelta_opt.apply_gradients(
+ zip([grads, grads], [var0, var1]))
+
+ opt_vars = adadelta_opt.variables()
+ self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[1].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[2].name, var1._shared_name)
+ self.assertStartsWith(opt_vars[3].name, var1._shared_name)
+ self.assertEqual(4, len(opt_vars))
+
+ variables.global_variables_initializer().run()
+
+ # Assign slots
+ slot = [None] * 2
+ slot_update = [None] * 2
+ self.assertEqual(["accum", "accum_update"],
+ adadelta_opt.get_slot_names())
+ slot[0] = adadelta_opt.get_slot(var0, "accum")
+ self.assertEquals(slot[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot[0] in variables.trainable_variables())
+
+ slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
+ self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot_update[0] in variables.trainable_variables())
+
+ slot[1] = adadelta_opt.get_slot(var1, "accum")
+ self.assertEquals(slot[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot[1] in variables.trainable_variables())
+
+ slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
+ self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot_update[1] in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, var0.eval())
+ self.assertAllClose(var1_init, var1.eval())
+
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ adadelta_update.run()
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (accum_update * rho + (update[step]**2) *
+ (1.0 - rho))
+ tot_update += update[step] * lr
+
+ # Check that the accumulators have been updated
+ for slot_idx in range(2):
+ self.assertAllCloseAccordingToType(
+ np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
+ slot[slot_idx].eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [accum_update, accum_update],
+ dtype=dtype.as_numpy_dtype()),
+ slot_update[slot_idx].eval(),
+ rtol=1e-5)
+
+ # Check that the parameters have been updated
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var0_init[0] - tot_update, var0_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var0.eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var1_init[0] - tot_update, var1_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var1.eval(),
+ rtol=1e-5)
+
+ def testBasic(self):
+ self.doTestBasic(use_resource=False)
+
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adadelta.Adadelta(1.0, 1.0, 1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-111, -138]], var0.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
new file mode 100644
index 0000000000..2d8cec2300
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -0,0 +1,119 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adagrad optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import training_ops
+
+
+class Adagrad(optimizer_v2.OptimizerV2):
+ """Adagrad optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+ or this
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
+
+ The learning_rate arg below is a hyperparameter, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ initial_accumulator_value: A floating point value. Starting value for the
+ accumulators, must be positive.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adagrad'.
+
+ Raises:
+ ValueError: If the `initial_accumulator_value` is invalid.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ initial_accumulator_value=0.1,
+ name="Adagrad"):
+ if initial_accumulator_value <= 0.0:
+ raise ValueError("initial_accumulator_value must be positive: %s" %
+ initial_accumulator_value)
+ super(Adagrad, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+
+ self._initial_accumulator_value = initial_accumulator_value
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ dtype = v.dtype.base_dtype
+ if v.get_shape().is_fully_defined():
+ init = init_ops.constant_initializer(self._initial_accumulator_value,
+ dtype=dtype)
+ else:
+ def init(v=v, dtype=dtype):
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
+ state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
+ "accumulator")
+
+ def _apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.sparse_apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_sparse_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
new file mode 100644
index 0000000000..fc4ef5c399
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -0,0 +1,276 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for aggregate operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdagradOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testBasic(self):
+ self.doTestBasic()
+
+ def testBasicResource(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable(
+ [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adagrad.Adagrad(1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType(
+ [[1.0, 2.0], [3.0, 4.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0, 1], [3, 4]], var0.eval(), atol=0.01)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(
+ constant_op.constant(3.0), initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+ # Run 3 step of sgd
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([[-1.6026098728179932], [2.0]]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[3.0], [3.715679168701172]]), var1.eval())
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def testSparseRepeatedIndicesResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var_repeated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_repeated = math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_repeated, [0, 0]))
+ var_aggregated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_aggregated = 2 * math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_aggregated, [0]))
+ update_op_repeated = adagrad.Adagrad(2.0).minimize(loss_repeated)
+ update_op_aggregated = adagrad.Adagrad(2.0).minimize(loss_aggregated)
+ variables.global_variables_initializer().run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+ for _ in range(3):
+ update_op_repeated.run()
+ update_op_aggregated.run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+
+ def testSparseStability(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ shape = [1, 6]
+ var0 = variables.Variable(
+ [[
+ 0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257,
+ -0.0105945
+ ]],
+ dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[
+ -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
+ -8.4877e-05, -9.48906e-05
+ ]],
+ shape=shape,
+ dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant(shape))
+ ada_opt = adagrad.Adagrad(1.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ init = variables.global_variables_initializer()
+ for _ in range(100):
+ init.run()
+ ada_update.run()
+ self.assertAllCloseAccordingToType(
+ np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[
+ 0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
+ -0.01029443
+ ]]), var0.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0)
+ # Apply the optimizer twice. Both applications will use
+ # the same accums.
+ ada_update1 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ ada_update2 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values.
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Mix the first and the second adagrad for 3 steps.
+ ada_update1.run()
+ ada_update2.run()
+ ada_update1.run()
+ # Validate updated params (the same as with only 1 Adagrad).
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testDynamicShapeVariable_Ok(self):
+ with self.cached_session():
+ v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
+ validate_shape=False)
+ self.assertFalse(v.shape.is_fully_defined())
+ # Creating optimizer should cause no exception.
+ adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
new file mode 100644
index 0000000000..8367228d7a
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -0,0 +1,203 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adam optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import training_ops
+
+
+class Adam(optimizer_v2.OptimizerV2):
+ r"""Adam Optimizer.
+
+ Default parameters follow those provided in the original paper.
+
+ See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+ ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+
+ Some of the args below are hyperparameters where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Initialization:
+
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section2 of the paper:
+
+ $$t := t + 1$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
+
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
+ formulation just before Section 2.1 of the Kingma and Ba paper rather than
+ the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+ hat" in the paper.
+
+ The sparse implementation of this algorithm (used when the gradient is an
+ IndexedSlices object, typically because of `tf.gather` or an embedding
+ lookup in the forward pass) does apply momentum to variable slices even if
+ they were not used in the forward pass (meaning they have a gradient equal
+ to zero). Momentum decay (beta1) is also applied to the entire momentum
+ accumulator. This means that the sparse behavior is equivalent to the dense
+ behavior (in contrast to some momentum implementations which ignore momentum
+ unless a variable slice was actually used).
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ beta_1: float hyperparameter, 0 < beta_1 < 1. Generally close to 1. The
+ exponential decay rate for the 1st moment estimates.
+ beta_2: float hyperparameter, 0 < beta_2 < 1. Generally close to 1. The
+ exponential decay rate for the 2nd moment estimates.
+ epsilon: float hyperparameter >= 0. Fuzz factor. This epsilon is "epsilon
+ hat" in the Kingma and Ba paper (in the formula just before Section
+ 2.1), not the epsilon in Algorithm 1 of the paper.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-8,
+ name="Adam"):
+ super(Adam, self).__init__(name)
+
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("beta_1", beta_1)
+ self._set_hyper("beta_2", beta_2)
+ self._set_hyper("epsilon", epsilon)
+
+ def _get_beta_accumulators(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return (state.get_non_slot("beta_1_power"),
+ state.get_non_slot("beta_2_power"))
+
+ def _create_vars(self, var_list, state):
+ # Non-slot variables end up on the same device(s).
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_1"), name="beta_1_power")
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_2"), name="beta_2_power")
+
+ # Create slots for the first and second moments.
+ for v in var_list:
+ state.zeros_slot(v, "m")
+ state.zeros_slot(v, "v")
+
+ def _apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.apply_adam(
+ var,
+ m,
+ v,
+ math_ops.cast(beta_1_power, var.dtype.base_dtype),
+ math_ops.cast(beta_2_power, var.dtype.base_dtype),
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("beta_1", var.dtype.base_dtype),
+ state.get_hyper("beta_2", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.resource_apply_adam(
+ var.handle,
+ m.handle,
+ v.handle,
+ math_ops.cast(beta_1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta_2_power, grad.dtype.base_dtype),
+ state.get_hyper("learning_rate", grad.dtype.base_dtype),
+ state.get_hyper("beta_1", grad.dtype.base_dtype),
+ state.get_hyper("beta_2", grad.dtype.base_dtype),
+ state.get_hyper("epsilon", grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ beta_1_power = math_ops.cast(beta_1_power, var.dtype.base_dtype)
+ beta_2_power = math_ops.cast(beta_2_power, var.dtype.base_dtype)
+ lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
+ beta_1_t = state.get_hyper("beta_1", var.dtype.base_dtype)
+ beta_2_t = state.get_hyper("beta_2", var.dtype.base_dtype)
+ epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
+ # m_t = beta_1 * m + (1 - beta_1) * g_t
+ m = state.get_slot(var, "m")
+ m_scaled_g_values = grad * (1 - beta_1_t)
+ m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
+ with ops.control_dependencies([m_t]):
+ m_t = scatter_add(m, indices, m_scaled_g_values)
+ # v_t = beta_2 * v + (1 - beta_2) * (g_t * g_t)
+ v = state.get_slot(var, "v")
+ v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
+ v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
+ with ops.control_dependencies([v_t]):
+ v_t = scatter_add(v, indices, v_scaled_g_values)
+ v_sqrt = math_ops.sqrt(v_t)
+ var_update = state_ops.assign_sub(var,
+ lr * m_t / (v_sqrt + epsilon_t),
+ use_locking=self._use_locking)
+ return control_flow_ops.group(*[var_update, m_t, v_t])
+
+ def _apply_sparse(self, grad, var, state):
+ return self._apply_sparse_shared(
+ grad.values, var, grad.indices,
+ lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
+ x, i, v, use_locking=self._use_locking),
+ state)
+
+ def _resource_scatter_add(self, x, i, v):
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(
+ x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ return self._apply_sparse_shared(
+ grad, var, indices, self._resource_scatter_add, state)
+
+ def _finish(self, state):
+ # Update the power accumulators.
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ update_beta_1 = beta_1_power.assign(
+ beta_1_power * state.get_hyper("beta_1"), use_locking=self._use_locking)
+ update_beta_2 = beta_2_power.assign(
+ beta_2_power * state.get_hyper("beta_2"), use_locking=self._use_locking)
+ return control_flow_ops.group(update_beta_1, update_beta_2)
diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py
new file mode 100644
index 0000000000..77796317a1
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -0,0 +1,333 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adam optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def adam_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ alpha=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8):
+ alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ return param_t, m_t, v_t
+
+
+class AdamOptimizerTest(test.TestCase):
+
+ def doTestSparse(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([2]))
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSparse(self):
+ self.doTestSparse(use_resource=False)
+
+ def testResourceSparse(self):
+ self.doTestSparse(use_resource=True)
+
+ def testSparseDevicePlacement(self):
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(force_gpu=test.is_gpu_available()):
+ # If a GPU is available, tests that all optimizer ops can be placed on
+ # it (i.e. they have GPU kernels).
+ var = variables.Variable([[1.0], [2.0]])
+ indices = constant_op.constant([0, 1], dtype=index_dtype)
+ gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
+ optimizer = adam.Adam(3.0)
+ minimize_op = optimizer.minimize(gathered_sum)
+ variables.global_variables_initializer().run()
+ minimize_op.run()
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adam.Adam().apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adam.Adam().apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def doTestBasic(self, use_resource=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertTrue(beta1_power is not None)
+ self.assertTrue(beta2_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = adam.Adam()
+ g = ops.Graph()
+ with g.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = adam.Adam(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables, and two unique slot variables
+ # for v1 and v2 respectively.
+ self.assertEqual(6, len(set(opt.variables())))
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
new file mode 100644
index 0000000000..338c04148b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
@@ -0,0 +1,761 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(josh11b): Forked from contrib/eager/python to test OptimizerV2 the same way
+# OptimizerV1 is tested. This file should be removed once the fork is resolved.
+
+import functools
+import os
+
+import six
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as core_saver
+from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
+
+
+class NonLayerCheckpointable(tracking.Checkpointable):
+
+ def __init__(self):
+ super(NonLayerCheckpointable, self).__init__()
+ self.a_variable = util.add_variable(
+ self, name="a_variable", shape=[])
+
+
+# pylint: disable=not-callable
+class MyModel(training.Model):
+ """A concrete Model for testing."""
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self._named_dense = core.Dense(1, use_bias=True)
+ self._second = core.Dense(1, use_bias=False)
+ # We can still track Checkpointables which aren't Layers.
+ self._non_layer = NonLayerCheckpointable()
+
+ def call(self, values):
+ ret = self._second(self._named_dense(values))
+ return ret
+
+
+class _MirroringSaveable(
+ core_saver.BaseSaverBuilder.ResourceVariableSaveable):
+
+ def __init__(self, primary_variable, mirrored_variable, name):
+ self._primary_variable = primary_variable
+ self._mirrored_variable = mirrored_variable
+ super(_MirroringSaveable, self).__init__(
+ self._primary_variable, "", name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ tensor, = restored_tensors
+ return control_flow_ops.group(
+ self._primary_variable.assign(tensor),
+ self._mirrored_variable.assign(tensor))
+
+
+class CheckpointingTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testNamingWithOptimizer(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ # A nuisance Model using the same optimizer. Its slot variables should not
+ # go in the checkpoint, since it is never depended on.
+ other_model = MyModel()
+ optimizer = adam.Adam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value),
+ global_step=optimizer_step)
+ optimizer.minimize(
+ lambda: other_model(input_value),
+ global_step=optimizer_step)
+ else:
+ train_op = optimizer.minimize(
+ model(input_value), global_step=optimizer_step)
+ optimizer.minimize(
+ other_model(input_value),
+ global_step=optimizer_step)
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ named_variables, serialized_graph, _ = (
+ util._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
+ expected_checkpoint_names = (
+ # Created in the root node, so no prefix.
+ "optimizer_step",
+ "model/_second/kernel",
+ "model/_named_dense/kernel",
+ "model/_named_dense/bias",
+ # non-Layer dependency of the model
+ "model/_non_layer/a_variable",
+ # The optimizer creates two non-slot variables
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
+ # Slot variables
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
+ )
+ suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
+ expected_checkpoint_names = [
+ name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
+ six.assertCountEqual(self, expected_checkpoint_names,
+ named_variables.keys())
+ # Check that we've mapped to the right variable objects (not exhaustive)
+ self.assertEqual(
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
+ self.assertEqual(
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
+ self.assertEqual(
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
+ # Spot check the generated protocol buffers.
+ self.assertEqual("optimizer",
+ serialized_graph.nodes[0].children[1].local_name)
+ optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
+ 1].node_id]
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .original_variable_node_id]
+ .attributes[0].full_name)
+ # We strip off the :0 suffix, as variable.name-based saving does.
+ self.assertEqual(
+ "my_model/dense/kernel/Adam",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .slot_variable_node_id]
+ .attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel/Adam:0",
+ optimizer.get_slot(
+ var=model._named_dense.kernel,
+ name="m").name)
+ self.assertEqual(
+ "model/_named_dense/kernel" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .original_variable_node_id].attributes[0].checkpoint_key)
+ self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
+ self.assertEqual(
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .slot_variable_node_id].attributes[0].checkpoint_key)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestore(self):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model)
+ input_value = constant_op.constant([[3.]])
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value))
+ else:
+ train_op = optimizer.minimize(model(input_value))
+ # TODO(allenl): Make initialization more pleasant when graph building.
+ root_checkpointable.save_counter # pylint: disable=pointless-statement
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
+ m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
+ self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
+ save_path = root_checkpointable.save(file_prefix=prefix)
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
+ self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
+ optimizer_variables = self.evaluate(optimizer.variables())
+ self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
+ # Immediate restoration
+ status = root_checkpointable.restore(save_path=save_path).assert_consumed()
+ status.run_restore_ops()
+ self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
+ self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
+ self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
+ if not context.executing_eagerly():
+ return # Restore-on-create is only supported when executing eagerly
+ on_create_model = MyModel()
+ on_create_optimizer = adam.Adam(
+ 0.001,
+ # Preserve beta_1_power and beta_2_power when appying gradients
+ # so we can test that they've been restored correctly.
+ beta_1=1.0,
+ beta_2=1.0)
+ on_create_root = util.Checkpoint(
+ optimizer=on_create_optimizer, model=on_create_model)
+ # Deferred restoration
+ status = on_create_root.restore(save_path=save_path)
+ on_create_model(constant_op.constant([[3.]])) # create variables
+ self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
+ self.assertAllEqual([42.],
+ self.evaluate(
+ on_create_model._named_dense.variables[1]))
+ on_create_m_bias_slot = on_create_optimizer.get_slot(
+ on_create_model._named_dense.variables[1], "m")
+ # Optimizer slot variables are created when the original variable is
+ # restored.
+ self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
+ self.assertAllEqual(optimizer_variables[2:],
+ self.evaluate(on_create_optimizer.variables()))
+ dummy_var = resource_variable_ops.ResourceVariable([1.])
+ on_create_optimizer.minimize(loss=dummy_var.read_value)
+ status.assert_consumed()
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
+
+ # TODO(allenl): Debug garbage created by this test in python3.
+ def testDeferredRestorationUsageEager(self):
+ """An idiomatic eager execution example."""
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ optimizer_step=training_util.get_or_create_global_step())
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
+ for _ in range(num_training_steps):
+ # TODO(allenl): Use a Dataset and serialize/checkpoint it.
+ input_value = constant_op.constant([[3.]])
+ optimizer.minimize(
+ lambda: model(input_value), # pylint: disable=cell-var-from-loop
+ global_step=root.optimizer_step)
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ root.optimizer_step.numpy())
+
+ def testUsageGraph(self):
+ """Expected usage when graph building."""
+ with context.graph_mode():
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default():
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ input_value = constant_op.constant([[3.]])
+ train_op = optimizer.minimize(
+ model(input_value),
+ global_step=root.global_step)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ with self.session(graph=ops.get_default_graph()) as session:
+ status = root.restore(save_path=checkpoint_path)
+ status.initialize_or_restore(session=session)
+ if checkpoint_path is None:
+ self.assertEqual(0, training_continuation)
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
+ else:
+ status.assert_consumed()
+ for _ in range(num_training_steps):
+ session.run(train_op)
+ root.save(file_prefix=checkpoint_prefix, session=session)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ session.run(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ session.run(root.save_counter))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAgnosticUsage(self):
+ """Graph/eager agnostic usage."""
+ # Does create garbage when executing eagerly due to ops.Graph() creation.
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ input_value = constant_op.constant([[3.]])
+ train_fn = functools.partial(
+ optimizer.minimize,
+ functools.partial(model, input_value),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(self.evaluate, train_fn())
+ status.initialize_or_restore()
+ for _ in range(num_training_steps):
+ train_fn()
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ self.evaluate(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ self.evaluate(root.save_counter))
+
+ # pylint: disable=cell-var-from-loop
+ @test_util.run_in_graph_and_eager_modes
+ def testWithDefun(self):
+ num_training_steps = 2
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ # Don't actually train so we can test variable values
+ optimizer = adam.Adam(0.)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ def train_fn():
+ @function.defun
+ def _call_model(x):
+ return model(x)
+ with backprop.GradientTape() as tape:
+ loss = _call_model(constant_op.constant([[3.]]))
+ gradients = tape.gradient(loss, model.variables)
+ return optimizer.apply_gradients(zip(gradients, model.variables),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(
+ self.evaluate, train_fn())
+ status.initialize_or_restore()
+ for _ in range(num_training_steps):
+ train_fn()
+ if training_continuation > 0:
+ status.assert_consumed()
+ self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
+ else:
+ self.evaluate(model.variables[0].assign([[42.]]))
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ self.evaluate(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ self.evaluate(root.save_counter))
+ # pylint: enable=cell-var-from-loop
+
+ def testAnonymousVarsInInit(self):
+
+ class Model(training.Model):
+
+ def __init__(self):
+ super(Model, self).__init__()
+ self.w = resource_variable_ops.ResourceVariable(0.0)
+ self.b = resource_variable_ops.ResourceVariable(0.0)
+ self.vars = [self.w, self.b]
+
+ def call(self, x):
+ return x * self.w + self.b
+
+ with context.eager_mode():
+ model = Model()
+ optimizer = adam.Adam(learning_rate=0.05)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ checkpoint = util.Checkpoint(
+ model=model, optimizer=optimizer)
+ for _ in range(2):
+ checkpoint.save(checkpoint_prefix)
+ with backprop.GradientTape() as tape:
+ loss = (constant_op.constant(1.)
+ - model(constant_op.constant(1.))) ** 2
+ grad = tape.gradient(loss, model.vars)
+ optimizer.apply_gradients(
+ [(g, v) for g, v in zip(grad, model.vars)])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDeferredSlotRestoration(self):
+ checkpoint_directory = self.get_temp_dir()
+
+ root = tracking.Checkpointable()
+ root.var = util.add_variable(
+ root, name="var", initializer=0.)
+ optimizer = adam.Adam(0.1)
+ if context.executing_eagerly():
+ optimizer.minimize(root.var.read_value)
+ else:
+ train_op = optimizer.minimize(root.var)
+ # Note that `optimizer` has not been added as a dependency of
+ # `root`. Create a one-off grouping so that slot variables for `root.var`
+ # get initialized too.
+ self.evaluate(util.gather_initializers(
+ util.Checkpoint(root=root, optimizer=optimizer)))
+ self.evaluate(train_op)
+ self.evaluate(state_ops.assign(root.var, 12.))
+ no_slots_path = util.CheckpointableSaver(root).save(
+ os.path.join(checkpoint_directory, "no_slots"))
+ root.optimizer = optimizer
+ self.evaluate(state_ops.assign(root.var, 13.))
+ self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
+ 14.))
+ slots_path = util.CheckpointableSaver(root).save(
+ os.path.join(checkpoint_directory, "with_slots"))
+ new_root = tracking.Checkpointable()
+ # Load the slot-containing checkpoint (deferred), then immediately overwrite
+ # the non-slot variable (also deferred).
+ slot_status = util.CheckpointableSaver(
+ new_root).restore(slots_path)
+ no_slot_status = util.CheckpointableSaver(
+ new_root).restore(no_slots_path)
+ with self.assertRaises(AssertionError):
+ no_slot_status.assert_consumed()
+ new_root.var = util.add_variable(
+ new_root, name="var", shape=[])
+ no_slot_status.assert_consumed()
+ no_slot_status.run_restore_ops()
+ self.assertEqual(12., self.evaluate(new_root.var))
+ new_root.optimizer = adam.Adam(0.1)
+ with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
+ slot_status.assert_consumed()
+ self.assertEqual(12., self.evaluate(new_root.var))
+ if context.executing_eagerly():
+ # Slot variables are only created with restoring initializers when
+ # executing eagerly.
+ self.assertEqual(14., self.evaluate(
+ new_root.optimizer.get_slot(name="m", var=new_root.var)))
+ else:
+ self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
+ None)
+ if context.executing_eagerly():
+ new_root.optimizer.minimize(new_root.var.read_value)
+ else:
+ train_op = new_root.optimizer.minimize(new_root.var)
+ # The slot variable now exists; restore() didn't create it, but we should
+ # now have a restore op for it.
+ slot_status.run_restore_ops()
+ self.assertEqual(14., self.evaluate(
+ new_root.optimizer.get_slot(name="m", var=new_root.var)))
+ self.evaluate(train_op)
+ slot_status.assert_consumed()
+
+ def testManySavesGraph(self):
+ """Saves after the first should not modify the graph."""
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.session(graph):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ obj = tracking.Checkpointable()
+ obj.var = variable_scope.get_variable(name="v", initializer=0.)
+ obj.opt = adam.Adam(0.1)
+ obj.opt.minimize(obj.var.read_value())
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
+ saver.save(checkpoint_prefix)
+ before_ops = graph.get_operations()
+ saver.save(checkpoint_prefix)
+ self.assertEqual(before_ops, graph.get_operations())
+
+ def testManyRestoresGraph(self):
+ """Restores after the first should not modify the graph."""
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.session(graph):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ obj = tracking.Checkpointable()
+ obj.var = variable_scope.get_variable(name="v", initializer=0.)
+ obj.opt = adam.Adam(0.1)
+ obj.opt.minimize(obj.var.read_value())
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
+ save_path = saver.save(checkpoint_prefix)
+ saver.restore(save_path)
+ before_ops = graph.get_operations()
+ saver.restore(save_path)
+ self.assertEqual(before_ops, graph.get_operations())
+
+ def testMultipleGraphsNonSlotVariables(self):
+ with context.graph_mode():
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ optimizer = adam.Adam(0.001)
+ # Construct a model in one graph
+ first_graph = ops.Graph()
+ first_session = session_lib.Session(graph=first_graph)
+ with first_graph.as_default(), first_session.as_default():
+ first_variable = resource_variable_ops.ResourceVariable([1.])
+ first_root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, variable=first_variable)
+ train_op = optimizer.minimize(first_variable.read_value)
+ self.evaluate(util.gather_initializers(
+ first_root_checkpointable))
+ self.evaluate(train_op)
+ self.evaluate(first_variable.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=first_variable, name="m").assign([2.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
+
+ # Save and load in a second graph
+ second_graph = ops.Graph()
+ with second_graph.as_default(), session_lib.Session(graph=second_graph):
+ second_variable = resource_variable_ops.ResourceVariable([1.])
+ second_root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, variable=second_variable)
+ train_op = optimizer.minimize(second_variable.read_value)
+ second_root_checkpointable.restore(None).initialize_or_restore()
+ self.evaluate(train_op)
+ self.evaluate(second_variable.assign([4.]))
+ self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m").assign([5.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(6.))
+ save_path = second_root_checkpointable.save(checkpoint_prefix)
+ self.evaluate(second_variable.assign([7.]))
+ self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m").assign([8.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
+ status = second_root_checkpointable.restore(save_path)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([4.], self.evaluate(second_variable))
+ self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m")))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
+
+ # Check that the first graph is unmolested
+ with first_graph.as_default(), first_session.as_default():
+ self.assertAllEqual([1.], self.evaluate(first_variable))
+ self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
+ var=first_variable, name="m")))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+
+class TemplateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_checkpointable_save_restore(self):
+
+ def _templated():
+ v = variable_scope.get_variable(
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ v2 = variable_scope.get_variable(
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ return v, v + 1., v2
+
+ save_template = template.make_template("s1", _templated)
+ v1_save, _, v2_save = save_template()
+ optimizer = adam.Adam(0.0)
+ save_root = util.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
+ self.evaluate(v1_save.assign([12.]))
+ self.evaluate(v2_save.assign([14.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_root.save(checkpoint_prefix)
+
+ load_template = template.make_template("s2", _templated)
+ load_optimizer = adam.Adam(0.0)
+ load_root = util.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
+ status = load_root.restore(save_path)
+ var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
+ self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
+ self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([12.], self.evaluate(var))
+ self.assertAllEqual([13.], self.evaluate(var_plus_one))
+ self.assertAllEqual([14.], self.evaluate(var2))
+
+
+class CheckpointCompatibilityTests(test.TestCase):
+
+ def _initialized_model(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ train_op = optimizer.minimize(
+ functools.partial(model, input_value),
+ global_step=optimizer_step)
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ # A regular variable, a slot variable, and a non-slot Optimizer variable
+ # with known values to check when loading.
+ self.evaluate(model._named_dense.bias.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=model._named_dense.bias, name="m").assign([2.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
+ return root_checkpointable
+
+ def _set_sentinels(self, root_checkpointable):
+ self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
+ self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")
+ .assign([102.]))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(103.))
+
+ def _check_sentinels(self, root_checkpointable):
+ self.assertAllEqual(
+ [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
+ self.assertAllEqual([2.], self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+ def _write_name_based_checkpoint(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ name_saver = core_saver.Saver()
+ return name_saver.save(
+ sess=session, save_path=checkpoint_prefix,
+ global_step=root.optimizer_step)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testLoadFromNameBasedSaver(self):
+ """Save a name-based checkpoint, load it using the object-based API."""
+ with test_util.device(use_gpu=True):
+ save_path = self._write_name_based_checkpoint()
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ with self.assertRaises(AssertionError):
+ self._check_sentinels(root)
+ object_saver = util.CheckpointableSaver(root)
+ self._set_sentinels(root)
+ status = object_saver.restore(save_path)
+ if context.executing_eagerly():
+ self._check_sentinels(root)
+ if context.executing_eagerly():
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_consumed()
+ else:
+ # When graph building, we haven't read any keys, so we don't know
+ # whether the restore will be complete.
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_consumed()
+ status.run_restore_ops()
+ self._check_sentinels(root)
+ self._set_sentinels(root)
+ status = object_saver.restore(save_path)
+ status.initialize_or_restore()
+ self._check_sentinels(root)
+
+ # TODO(allenl): Test for the core name-based saver loading object-based
+ # checkpoints once object-based checkpointing is in core.
+
+ def testSaveGraphLoadEager(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ save_path = root.save(
+ session=session, file_prefix=checkpoint_prefix)
+ with context.eager_mode():
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed()
+ self._check_sentinels(root)
+
+ def testSaveEagerLoadGraph(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.eager_mode():
+ root = self._initialized_model()
+ save_path = root.save(file_prefix=checkpoint_prefix)
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph):
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed().run_restore_ops()
+ self._check_sentinels(root)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
new file mode 100644
index 0000000000..bd5557f4fd
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -0,0 +1,1349 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Version 2 of class Optimizer."""
+# pylint: disable=g-bad-name
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import optimizer as optimizer_v1
+from tensorflow.python.training import slot_creator
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
+
+
+class _OptimizableVariable(object):
+ """Interface for abstracting over variables in the optimizers."""
+
+ @abc.abstractmethod
+ def target(self):
+ """Returns the optimization target for this variable."""
+ raise NotImplementedError("Calling an abstract method.")
+
+ @abc.abstractmethod
+ def update_op(self, optimizer, g, *args):
+ """Returns the update ops for updating the variable."""
+ raise NotImplementedError("Calling an abstract method.")
+
+
+class _RefVariableProcessor(_OptimizableVariable):
+ """Processor for Variable."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v._ref() # pylint: disable=protected-access
+
+ def update_op(self, optimizer, g, *args):
+ if isinstance(g, ops.Tensor):
+ update_op = optimizer._apply_dense(g, self._v, *args) # pylint: disable=protected-access
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+ else:
+ assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
+ "tensor nor IndexedSlices.")
+ if self._v.constraint is not None:
+ raise RuntimeError(
+ "Cannot use a constraint function on a sparse variable.")
+ # pylint: disable=protected-access
+ return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)
+
+
+class _DenseReadResourceVariableProcessor(_OptimizableVariable):
+ """Processor for dense ResourceVariables."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ # pylint: disable=protected-access
+ update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+
+
+class _DenseResourceVariableProcessor(_OptimizableVariable):
+ """Processor for dense ResourceVariables."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ # pylint: disable=protected-access
+ if isinstance(g, ops.IndexedSlices):
+ if self._v.constraint is not None:
+ raise RuntimeError(
+ "Cannot use a constraint function on a sparse variable.")
+ return optimizer._resource_apply_sparse_duplicate_indices(
+ g.values, self._v, g.indices, *args)
+ update_op = optimizer._resource_apply_dense(g, self._v, *args)
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+
+
+class _TensorProcessor(_OptimizableVariable):
+ """Processor for ordinary Tensors.
+
+ Even though a Tensor can't really be updated, sometimes it is useful to
+ compute the gradients with respect to a Tensor using the optimizer. Updating
+ the Tensor is, of course, unsupported.
+ """
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ raise NotImplementedError("Trying to update a Tensor ", self._v)
+
+
+def _get_processor(v):
+ """The processor of v."""
+ if context.executing_eagerly():
+ if isinstance(v, ops.Tensor):
+ return _TensorProcessor(v)
+ else:
+ return _DenseResourceVariableProcessor(v)
+ if v.op.type == "VarHandleOp":
+ return _DenseResourceVariableProcessor(v)
+ if isinstance(v, variables.Variable):
+ return _RefVariableProcessor(v)
+ if isinstance(v, ops.Tensor):
+ return _TensorProcessor(v)
+ raise NotImplementedError("Trying to optimize unsupported type ", v)
+
+
+def _var_key_v2(var):
+ """Key for representing a primary variable, for looking up slots."""
+ # pylint: disable=protected-access
+ if hasattr(var, "_distributed_container"):
+ distributed_container = var._distributed_container()
+ assert distributed_container is not None
+ if context.executing_eagerly():
+ return distributed_container._unique_id
+ return distributed_container._shared_name
+ if context.executing_eagerly():
+ return var._unique_id
+ return var.op.name
+
+
+def _resolve(value, name):
+ if callable(value):
+ value = value()
+ return ops.convert_to_tensor(value, name=name)
+
+
+def _is_dynamic(value):
+ """Returns true if __init__ arg `value` should be re-evaluated each step."""
+ if callable(value): return True
+ # Don't need to do anything special in graph mode, since dynamic values
+ # will propagate correctly automatically.
+ # TODO(josh11b): Add per-device caching across steps using variables for
+ # truly static values once we add distributed support.
+ if context.executing_eagerly() and isinstance(
+ value, resource_variable_ops.ResourceVariable):
+ return True
+ return False
+
+
+class _OptimizerV2State(object):
+ """Holds per-graph and per-step optimizer state.
+
+ Use _init_with_static_hyper() to create the state for a graph, and then
+ _copy_with_dynamic_hyper() to convert that to state for a particular step.
+ The difference between the two is that the former only has hyper
+ parameter values that are static and the latter also has values that
+ can change every step (according to _is_dynamic()).
+ """
+
+ def __init__(self, op_name):
+ self._op_name = op_name
+
+ def _init_with_static_hyper(self, hyper):
+ """Initialize a fresh state object from hyper dict."""
+ # self._hyper contains a dict from name to a dict with the Tensor values.
+ # This dict starts with a single item with key "None" with the hyper
+ # parameter value converted to a Tensor. Other items have dtype keys
+ # with that Tensor cast to that dtype.
+ with ops.init_scope():
+ self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if not dynamic}
+ self._slots = {}
+ self._non_slot_dict = {}
+ # Extra state to help Optimizers implement Checkpointable. Holds information
+ # about variables which will be restored as soon as they're created.
+ self._deferred_dependencies = {} # Non-slot variables
+ self._deferred_slot_restorations = {} # Slot variables
+
+ def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
+ """Create a new state object for a particular step."""
+ ret = _OptimizerV2State(self._op_name)
+ # pylint: disable=protected-access
+ ret._slots = self._slots
+ ret._non_slot_dict = self._non_slot_dict
+ ret._deferred_dependencies = self._deferred_dependencies
+ ret._deferred_slot_restorations = self._deferred_slot_restorations
+ ret._hyper = {name: {None: _resolve(value, name)}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if dynamic}
+ ret._hyper.update(self._hyper)
+ ret._non_slot_devices = non_slot_devices
+ ret._distribution = distribution
+ return ret
+
+ def _variables(self):
+ """Returns a list of all variables held by self."""
+ optimizer_variables = list(self._non_slot_dict.values())
+ for variable_dict in self._slots.values():
+ for slot_for_variable in variable_dict.values():
+ optimizer_variables.append(slot_for_variable)
+ # Sort variables by name so that the return is deterministic.
+ return sorted(optimizer_variables, key=lambda v: v.name)
+
+ def _slot_dict(self, slot_name):
+ """Returns a dict for caching slots created under the given name.
+
+ Args:
+ slot_name: Name for the slot.
+
+ Returns:
+ A dict that maps primary `Variable` objects to the slot created
+ for that variable, under the given slot name.
+ """
+ named_slots = self._slots.get(slot_name, None)
+ if named_slots is None:
+ named_slots = {}
+ self._slots[slot_name] = named_slots
+ return named_slots
+
+ def create_slot(self, var, val, slot_name, optional_op_name=None):
+ """Find or create a slot for a variable.
+
+ Args:
+ var: A `Variable` object.
+ val: A `Tensor`. The initial value of the slot.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_slot(
+ var, val, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def create_slot_with_initializer(self, var, initializer, shape, dtype,
+ slot_name, optional_op_name=None):
+ """Find or create a slot for a variable, using an Initializer.
+
+ Args:
+ var: A `Variable` object.
+ initializer: An `Initializer`. The initial value of the slot.
+ shape: Shape of the initial value of the slot.
+ dtype: Type of the value of the slot.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_slot_with_initializer(
+ var, initializer, shape, dtype, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def zeros_slot(self, var, slot_name, optional_op_name=None):
+ """Find or create a slot initialized with 0.0.
+
+ Args:
+ var: A `Variable` object.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_zeros_slot(
+ var, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def _create_or_restore_slot_variable(
+ self, slot_variable_position, slot_name, variable,
+ optional_op_name=None):
+ """Restore a slot variable's value, possibly creating it.
+
+ Called when a variable which has an associated slot variable is created or
+ restored. When executing eagerly, we create the slot variable with a
+ restoring initializer.
+
+ No new variables are created when graph building. Instead,
+ _restore_slot_variable catches these after normal creation and adds restore
+ ops to the graph. This method is nonetheless important when graph building
+ for the case when a slot variable has already been created but `variable`
+ has just been added to a dependency graph (causing us to realize that the
+ slot variable needs to be restored).
+
+ Args:
+ slot_variable_position: A `checkpointable._CheckpointPosition` object
+ indicating the slot variable `Checkpointable` object to be restored.
+ slot_name: The name of this `Optimizer`'s slot to restore into.
+ variable: The variable object this slot is being created for.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+ """
+ slot_variable = self.get_slot(var=variable, name=slot_name)
+ if (slot_variable is None and context.executing_eagerly() and
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
+ initializer = checkpointable.CheckpointInitialValue(
+ checkpoint_position=slot_variable_position)
+ slot_variable = self.create_slot(
+ var=variable,
+ val=initializer,
+ slot_name=slot_name,
+ optional_op_name=optional_op_name)
+ # Optimizers do not have unconditional dependencies on their slot
+ # variables (nor do any other objects). They are only saved if the
+ # variables they were created for are also saved.
+ if slot_variable is not None:
+ # If we've either made this slot variable, or if we've pulled out an
+ # existing slot variable, we should restore it.
+ slot_variable_position.restore(slot_variable)
+ else:
+ # We didn't make the slot variable. Defer restoring until it gets created
+ # normally. We keep a list rather than the one with the highest restore
+ # UID in case slot variables have their own dependencies, in which case
+ # those could differ between restores.
+ variable_key = _var_key_v2(variable)
+ self._deferred_slot_restorations.setdefault(
+ slot_name, {}).setdefault(variable_key, []).append(
+ slot_variable_position)
+
+ def get_slot(self, var, name):
+ """Return a slot named `name` created for `var` by the Optimizer.
+
+ Some `Optimizer` subclasses use additional variables. For example
+ `Momentum` and `Adagrad` use variables to accumulate updates. This method
+ gives access to these `Variable` objects if for some reason you need them.
+
+ Use `get_slot_names()` to get the list of slot names created by the
+ `Optimizer`.
+
+ Args:
+ var: A variable passed to `minimize()` or `apply_gradients()`.
+ name: A string.
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ named_slots = self._slots.get(name, None)
+ if not named_slots:
+ return None
+ return named_slots.get(_var_key_v2(var), None)
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ See `get_slot()`.
+
+ Returns:
+ A list of strings.
+ """
+ return sorted(self._slots.keys())
+
+ def create_non_slot(self, initial_value, name, colocate_with=None):
+ """Add an extra variable, not associated with a slot."""
+ v = self._non_slot_dict.get(name, None)
+ if v is None:
+ if colocate_with is None: colocate_with = self._non_slot_devices
+ with self._distribution.colocate_vars_with(colocate_with):
+ # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
+ v = variable_scope.variable(initial_value, name=name, trainable=False)
+ self._non_slot_dict[name] = v
+ deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
+ for checkpoint_position in sorted(
+ deferred_dependencies_list,
+ key=lambda restore: restore.checkpoint.restore_uid,
+ reverse=True):
+ checkpoint_position.restore(v)
+ return v
+
+ def _restore_slot_variable(self, slot_name, variable, slot_variable):
+ """Restore a newly created slot variable's value."""
+ variable_key = _var_key_v2(variable)
+ deferred_restorations = self._deferred_slot_restorations.get(
+ slot_name, {}).pop(variable_key, [])
+ # Iterate over restores, highest restore UID first to minimize the number
+ # of assignments.
+ deferred_restorations.sort(key=lambda position: position.restore_uid,
+ reverse=True)
+ for checkpoint_position in deferred_restorations:
+ checkpoint_position.restore(slot_variable)
+
+ def get_non_slot(self, name):
+ """Returns the non-slot variable identified by `name`."""
+ return self._non_slot_dict.get(name, None)
+
+ def get_hyper(self, name, dtype=None):
+ """Returns the `name` hyper parameter, optionally cast to `dtype`."""
+ dtype_dict = self._hyper[name]
+ # Do we have the value cast to dtype already cached? This should always
+ # succeed when dtype is None.
+ if dtype in dtype_dict:
+ return dtype_dict[dtype]
+ # Not cached, cast to dtype and save the result in the cache.
+ result = math_ops.cast(dtype_dict[None], dtype)
+ dtype_dict[dtype] = result
+ return result
+
+
+class OptimizerV2(optimizer_v1.Optimizer):
+ """Updated base class for optimizers.
+
+ This class defines the API to add Ops to train a model. You never use this
+ class directly, but instead instantiate one of its subclasses such as
+ `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
+
+ ### Usage
+
+ ```python
+ # Create an optimizer with the desired parameters.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+ # Add Ops to the graph to minimize a cost by updating a list of variables.
+ # "cost" is a Tensor, and the list of variables contains tf.Variable
+ # objects.
+ opt_op = opt.minimize(cost, var_list=<list of variables>)
+ ```
+
+ In the training program you will just have to run the returned Op.
+
+ ```python
+ # Execute opt_op to do one step of training:
+ opt_op.run()
+ ```
+
+ ### Processing gradients before applying them.
+
+ Calling `minimize()` takes care of both computing the gradients and
+ applying them to the variables. If you want to process the gradients
+ before applying them you can instead use the optimizer in three steps:
+
+ 1. Compute the gradients with `compute_gradients()`.
+ 2. Process the gradients as you wish.
+ 3. Apply the processed gradients with `apply_gradients()`.
+
+ Example:
+
+ ```python
+ # Create an optimizer.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Compute the gradients for a list of variables.
+ grads_and_vars = opt.compute_gradients(loss, <list of variables>)
+
+ # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
+ # need to the 'gradient' part, for example cap them, etc.
+ capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
+
+ # Ask the optimizer to apply the capped gradients.
+ opt.apply_gradients(capped_grads_and_vars)
+ ```
+
+ ### Gating Gradients
+
+ Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
+ argument that controls the degree of parallelism during the application of
+ the gradients.
+
+ The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
+
+ <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
+ the maximum parallelism in execution, at the cost of some non-reproducibility
+ in the results. For example the two gradients of `matmul` depend on the input
+ values: With `GATE_NONE` one of the gradients could be applied to one of the
+ inputs _before_ the other gradient is computed resulting in non-reproducible
+ results.
+
+ <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
+ they are used. This prevents race conditions for Ops that generate gradients
+ for multiple inputs where the gradients depend on the inputs.
+
+ <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
+ before any one of them is used. This provides the least parallelism but can
+ be useful if you want to process all gradients before applying any of them.
+
+ ### Slots
+
+ Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
+ allocate and manage additional variables associated with the variables to
+ train. These are called <i>Slots</i>. Slots have names and you can ask the
+ optimizer for the names of the slots that it uses. Once you have a slot name
+ you can ask the optimizer for the variable it created to hold the slot value.
+
+ This can be useful if you want to log debug a training algorithm, report stats
+ about the slots, etc.
+
+ ### Non-slot variables
+
+ Some optimizer subclasses, such as `AdamOptimizer` have variables that
+ are not associated with the variables to train, just the step itself.
+
+ ### Hyper parameters
+
+ These are arguments passed to the optimizer subclass constructor
+ (the `__init__` method), and then passed to `self._set_hyper()`.
+ They can be either regular Python values (like 1.0), tensors, or
+ callables. If they are callable, the callable will be called during
+ `apply_gradients()` to get the value for the hyper parameter.
+
+ ### State
+
+ Internal methods are passed a `state` argument with the correct
+ values to use for the slot and non-slot variables, and the hyper
+ parameters.
+ """
+
+ # Values for gate_gradients.
+ GATE_NONE = 0
+ GATE_OP = 1
+ GATE_GRAPH = 2
+
+ def __init__(self, name):
+ """Create a new Optimizer.
+
+ This must be called by the constructors of subclasses.
+ Note that Optimizer instances should not bind to a single graph,
+ and so shouldn't keep Tensors as member variables. Generally
+ you should be able to use the _set_hyper()/state.get_hyper()
+ facility instead.
+
+ Args:
+ name: A non-empty string. The name to use for accumulators created
+ for the optimizer.
+
+ Raises:
+ ValueError: If name is malformed.
+ RuntimeError: If _create_slots has been overridden instead of
+ _create_vars.
+ """
+ # Note: We intentionally don't call parent __init__.
+
+ # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
+ if (self.__class__._create_slots.__code__ is not # pylint: disable=protected-access
+ OptimizerV2._create_slots.__code__):
+ raise RuntimeError("Override _create_vars instead of _create_slots when "
+ "descending from OptimizerV2 (class %s)" %
+ self.__class__.__name__)
+ if not name:
+ raise ValueError("Must specify the optimizer name")
+
+ self._use_locking = False
+ self._name = name
+ # Map from graph_key to state for that graph. We use the graph_key
+ # since it works in both eager and graph mode, and gives the outer
+ # graph inside functions.
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context is None:
+ # In a cross-tower context for a DistributionStrategy, which means
+ # only one Optimizer will be created, not one per tower.
+ self._per_graph_state = {}
+ else:
+ # We use get_tower_context().merge_call() to get a single dict
+ # shared across all model replicas when running with a
+ # DistributionStrategy.
+ self._per_graph_state = tower_context.merge_call(lambda _: {})
+
+ # Hyper parameters, and whether they should be re-evaluated every step.
+ self._hyper = {}
+
+ def _set_hyper(self, name, value):
+ self._hyper[name] = (_is_dynamic(value), value)
+
+ def minimize(self, loss, global_step=None, var_list=None,
+ gate_gradients=GATE_OP, aggregation_method=None,
+ colocate_gradients_with_ops=False, name=None,
+ grad_loss=None, stop_gradients=None,
+ scale_loss_by_num_towers=None):
+ """Add operations to minimize `loss` by updating `var_list`.
+
+ This method simply combines calls `compute_gradients()` and
+ `apply_gradients()`. If you want to process the gradient before applying
+ them call `compute_gradients()` and `apply_gradients()` explicitly instead
+ of using this function.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ var_list: Optional list or tuple of `Variable` objects to update to
+ minimize `loss`. Defaults to the list of variables collected in
+ the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ stop_gradients: Optional. A Tensor or list of tensors not to differentiate
+ through.
+ scale_loss_by_num_towers: Optional boolean. If true, scale the loss
+ down by the number of towers. By default, auto-detects whether this
+ is needed.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
+ Raises:
+ ValueError: If some of the variables are not `Variable` objects.
+
+ @compatibility(eager)
+ When eager execution is enabled, `loss` should be a Python function that
+ takes elements of `var_list` as arguments and computes the value to be
+ minimized. If `var_list` is None, `loss` should take no arguments.
+ Minimization (and gradient computation) is done with respect to the
+ elements of `var_list` if not None, else with respect to any trainable
+ variables created during the execution of the `loss` function.
+ `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
+ `grad_loss` are ignored when eager execution is enabled.
+ @end_compatibility
+ """
+ grads_and_vars = self.compute_gradients(
+ loss, var_list=var_list, gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss, stop_gradients=stop_gradients,
+ scale_loss_by_num_towers=scale_loss_by_num_towers)
+
+ vars_with_grad = [v for g, v in grads_and_vars if g is not None]
+ if not vars_with_grad:
+ raise ValueError(
+ "No gradients provided for any variable, check your graph for ops"
+ " that do not support gradients, between variables %s and loss %s." %
+ ([str(v) for _, v in grads_and_vars], loss))
+
+ return self.apply_gradients(grads_and_vars, global_step=global_step,
+ name=name)
+
+ def compute_gradients(self, loss, var_list=None,
+ gate_gradients=GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ grad_loss=None, stop_gradients=None,
+ scale_loss_by_num_towers=None):
+ """Compute gradients of `loss` for the variables in `var_list`.
+
+ This is the first part of `minimize()`. It returns a list
+ of (gradient, variable) pairs where "gradient" is the gradient
+ for "variable". Note that "gradient" can be a `Tensor`, an
+ `IndexedSlices`, or `None` if there is no gradient for the
+ given variable.
+
+ Args:
+ loss: A Tensor containing the value to minimize or a callable taking
+ no arguments which returns the value to minimize. When eager execution
+ is enabled it must be a callable.
+ var_list: Optional list or tuple of `tf.Variable` to update to minimize
+ `loss`. Defaults to the list of variables collected in the graph
+ under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ stop_gradients: Optional. A Tensor or list of tensors not to differentiate
+ through.
+ scale_loss_by_num_towers: Optional boolean. If true, scale the loss
+ down by the number of towers. By default, auto-detects whether this
+ is needed.
+
+ Returns:
+ A list of (gradient, variable) pairs. Variable is always present, but
+ gradient can be `None`.
+
+ Raises:
+ TypeError: If `var_list` contains anything else than `Variable` objects.
+ ValueError: If some arguments are invalid.
+ RuntimeError: If called with eager execution enabled and `loss` is
+ not callable.
+
+ @compatibility(eager)
+ When eager execution is enabled, `gate_gradients`, `aggregation_method`,
+ and `colocate_gradients_with_ops` are ignored.
+ @end_compatibility
+ """
+ # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+ if callable(loss):
+ with backprop.GradientTape() as tape:
+ if var_list is not None:
+ tape.watch(var_list)
+ loss_value = loss()
+
+ # Scale loss for number of towers (callable-loss case). In this case,
+ # we have to be careful to call distribute_lib.get_loss_reduction()
+ # *after* loss() is evaluated, so we know what loss reduction it uses.
+ if scale_loss_by_num_towers is None:
+ scale_loss_by_num_towers = (
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
+ if scale_loss_by_num_towers:
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
+ if num_towers > 1:
+ loss_value *= 1. / num_towers
+
+ if var_list is None:
+ var_list = tape.watched_variables()
+ grads = tape.gradient(loss_value, var_list, grad_loss)
+ return list(zip(grads, var_list))
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "`loss` passed to Optimizer.compute_gradients should "
+ "be a function when eager execution is enabled.")
+
+ # Scale loss for number of towers (non-callable-loss case).
+ if scale_loss_by_num_towers is None:
+ scale_loss_by_num_towers = (
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
+ if scale_loss_by_num_towers:
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
+ if num_towers > 1:
+ loss *= 1. / num_towers
+
+ if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
+ optimizer_v1.Optimizer.GATE_OP,
+ optimizer_v1.Optimizer.GATE_GRAPH]:
+ raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
+ "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
+ gate_gradients)
+ self._assert_valid_dtypes([loss])
+ if grad_loss is not None:
+ self._assert_valid_dtypes([grad_loss])
+ if var_list is None:
+ var_list = (
+ variables.trainable_variables() +
+ ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
+ else:
+ var_list = nest.flatten(var_list)
+ # pylint: disable=protected-access
+ var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
+ # pylint: enable=protected-access
+ processors = [_get_processor(v) for v in var_list]
+ if not var_list:
+ raise ValueError("No variables to optimize.")
+ var_refs = [p.target() for p in processors]
+ grads = gradients.gradients(
+ loss, var_refs, grad_ys=grad_loss,
+ gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ stop_gradients=stop_gradients)
+ if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
+ grads = control_flow_ops.tuple(grads)
+ grads_and_vars = list(zip(grads, var_list))
+ self._assert_valid_dtypes(
+ [v for g, v in grads_and_vars
+ if g is not None and v.dtype != dtypes.resource])
+ return grads_and_vars
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This is the second part of `minimize()`. It returns an `Operation` that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+
+ Raises:
+ TypeError: If `grads_and_vars` is malformed.
+ ValueError: If none of the variables have gradients.
+ """
+ # This is a default implementation of apply_gradients() that can be shared
+ # by most optimizers. It relies on the subclass implementing the following
+ # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().
+
+ # Filter out variables with gradients of `None`.
+ grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
+ if not grads_and_vars:
+ raise ValueError("No variables provided.")
+ filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
+ if not filtered:
+ raise ValueError("No gradients provided for any variable: %s." %
+ ([str(v) for _, v in grads_and_vars],))
+ return distribution_strategy_context.get_tower_context().merge_call(
+ self._distributed_apply, filtered, global_step=global_step, name=name)
+
+ def _get_or_create_state(self, var_list=None):
+ """Either looks up or creates `_OptimizerV2State`.
+
+ If any variables are available, they should be passed via the `var_list`
+ argument, and these will be used to determine the graph to create/retrieve
+ state for. Otherwise the returned state is for the current default graph.
+
+ Args:
+ var_list: A list of variables to extract a graph from.
+
+ Returns:
+ An `_OptimizerV2State` object.
+ """
+ # Determine the graph_key from the current graph.
+ eager_execution = context.executing_eagerly()
+ if eager_execution or var_list is None:
+ graph = ops.get_default_graph()
+ else:
+ graph = ops._get_graph_from_inputs(var_list) # pylint: disable=protected-access
+ assert graph is not None
+ graph_key = graph._graph_key # pylint: disable=protected-access
+
+ # Get the per graph state by looking up the graph_key.
+ if graph_key in self._per_graph_state:
+ per_graph_state = self._per_graph_state[graph_key]
+ else:
+ per_graph_state = _OptimizerV2State(self._name)
+ per_graph_state._init_with_static_hyper(self._hyper) # pylint: disable=protected-access
+ self._per_graph_state[graph_key] = per_graph_state
+ return per_graph_state
+
+ def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
+ """`apply_gradients` for use with a `DistributionStrategy`."""
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
+ var_list = [v for _, v in grads_and_vars]
+ grads_and_vars = zip(reduced_grads, var_list)
+
+ unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
+ eager_execution = context.executing_eagerly()
+ if eager_execution:
+ # Give a clear error in this case instead of "name not supported
+ # for Eager Tensors" when we compute non_slot_devices.
+ for v in unwrapped_var_list:
+ if isinstance(v, ops.Tensor):
+ raise NotImplementedError("Trying to update a Tensor ", v)
+
+ with ops.name_scope(name, self._name) as name:
+ per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
+ # Include the current value of any dynamic hyper parameters in `state`.
+ non_slot_devices = distribution.non_slot_devices(var_list)
+ state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
+ self._hyper, distribution, non_slot_devices)
+
+ # Create any slot and non-slot variables we need in `state`.
+ with ops.init_scope():
+ self._create_vars(var_list, state)
+
+ with ops.name_scope(name): # Re-enter name_scope created above
+ # Give the child class a chance to do something before we start
+ # applying gradients.
+ self._prepare(state)
+
+ def update(v, g):
+ """Update variable `v` using gradient `g`."""
+ assert v is not None
+
+ # Convert the grad to Tensor or IndexedSlices if necessary, and
+ # look up a processor for each variable's type.
+ try:
+ g = ops.convert_to_tensor_or_indexed_slices(g)
+ except TypeError:
+ raise TypeError(
+ "Gradient must be convertible to a Tensor"
+ " or IndexedSlices, or None: %s" % g)
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ processor = _get_processor(v)
+
+ # We colocate all ops created in _apply_dense or _apply_sparse
+ # on the same device as the variable.
+ # TODO(apassos): figure out how to get the variable name here.
+ scope_name = "" if eager_execution else v.op.name
+ # device_policy is set because non-mirrored tensors will be read in
+ # `update_op`.
+ # TODO(josh11b): Make different state objects for each device to
+ # avoid needing to set the device_policy.
+ with ops.name_scope("update_" + scope_name), \
+ context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ return processor.update_op(self, g, state)
+
+ # Use the processors to update the variables.
+ update_ops = []
+ for grad, var in grads_and_vars:
+ update_ops.extend(distribution.update(var, update, grad, grouped=False))
+
+ # Give the child class a chance to do something after applying
+ # gradients
+ def finish():
+ # TODO(josh11b): Make different state objects for each device to
+ # avoid needing to set the device_policy.
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ return self._finish(state)
+
+ update_ops = control_flow_ops.group(update_ops)
+ with ops.control_dependencies([update_ops]):
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, grouped=False)
+ # We said grouped=False, which means finish_updates is always a list.
+ # It will be [None] when finish() returns None.
+ if finish_updates == [None]:
+ finish_updates = [update_ops]
+
+ # Update `global_step` (if any).
+ if global_step is None:
+ apply_updates = distribution.group(finish_updates, name=name)
+ else:
+ with ops.control_dependencies(finish_updates):
+
+ def update_global_step(global_step, name):
+ return global_step.assign_add(1, read_value=False, name=name)
+
+ apply_updates = distribution.update(global_step, update_global_step,
+ name)
+
+ # Add the training op to the TRAIN_OP graph collection in graph mode.
+ if not eager_execution:
+ if isinstance(apply_updates, ops.Tensor):
+ apply_updates = apply_updates.op
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ if apply_updates not in train_op:
+ train_op.append(apply_updates)
+
+ return apply_updates
+
+ def get_slot(self, var, name):
+ """Return a slot named `name` created for `var` by the Optimizer.
+
+ Some `Optimizer` subclasses use additional variables. For example
+ `Momentum` and `Adagrad` use variables to accumulate updates. This method
+ gives access to these `Variable` objects if for some reason you need them.
+
+ Use `get_slot_names()` to get the list of slot names created by the
+ `Optimizer`.
+
+ Args:
+ var: A variable passed to `minimize()` or `apply_gradients()`.
+ name: A string.
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ state = self._get_state_for_var(var)
+ return state.get_slot(var, name) if state is not None else None
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ See `get_slot()`.
+
+ Returns:
+ A list of strings.
+ """
+ state = self._get_per_graph_state()
+ return state.get_slot_names() if state is not None else []
+
+ def variables(self):
+ """A list of variables which encode the current state of `Optimizer`.
+
+ Includes slot variables and additional global variables created by the
+ optimizer in the current default graph.
+
+ Returns:
+ A list of variables.
+ """
+ state = self._get_per_graph_state()
+ return state._variables() if state is not None else [] # pylint: disable=protected-access
+
+ # --------------
+ # Methods to be implemented by subclasses if they want to use the
+ # inherited implementation of apply_gradients() or compute_gradients().
+ # --------------
+ def _create_vars(self, var_list, state):
+ """Create all slots needed by the variables and any non-slot variables.
+
+ Args:
+ var_list: A list of `Variable` objects.
+ state: An object with these methods:
+ `create_slot(var, val, slot_name, optional_op_name)`,
+ `create_slot_with_initializer(`
+ `var, initializer, shape, dtype, slot_name, optional_op_name)`,
+ `zeros_slot(var, slot_name, optional_op_name)`,
+ `create_non_slot_variable(initial_value, name, colocate_with)`,
+ `get_hyper(name)`
+ """
+ # No slots needed by default
+ pass
+
+ def _prepare(self, state):
+ """Code to execute before applying gradients.
+
+ Note that most uses of _prepare() in Optimizer have been subsumed
+ by explicit support for hyper parameters in OptimizerV2
+
+ Args:
+ state: An object with a `get_hyper(name)` method.
+
+ Returns:
+ Return value will be ignored.
+ """
+ pass
+
+ def _apply_dense(self, grad, var, state):
+ """Add ops to apply dense gradients to `var`.
+
+ Args:
+ grad: A `Tensor`.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ raise NotImplementedError()
+
+ def _resource_apply_dense(self, grad, handle, state):
+ """Add ops to apply dense gradients to the variable `handle`.
+
+ Args:
+ grad: a `Tensor` representing the gradient.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
+ def _resource_apply_sparse_duplicate_indices(
+ self, grad, handle, indices, state):
+ """Add ops to apply sparse gradients to `handle`, with repeated indices.
+
+ Optimizers which override this method must deal with repeated indices. See
+ the docstring of `_apply_sparse_duplicate_indices` for details. By default
+ the correct behavior, to sum non-unique indices and their associated
+ gradients, is enforced by first pre-processing `grad` and `indices` and
+ passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
+ with duplicate indices may instead override this method to avoid the
+ overhead of summing.
+
+ Args:
+ grad: a `Tensor` representing the gradient for the affected indices.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ indices: a `Tensor` of integral type representing the indices for
+ which the gradient is nonzero. Indices may be repeated.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ # pylint: disable=protected-access
+ summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
+ values=grad, indices=indices)
+ # pylint: enable=protected-access
+ return self._resource_apply_sparse(
+ summed_grad, handle, unique_indices, state)
+
+ def _resource_apply_sparse(self, grad, handle, indices, state):
+ """Add ops to apply sparse gradients to the variable `handle`.
+
+ Similar to `_apply_sparse`, the `indices` argument to this method has been
+ de-duplicated. Optimizers which deal correctly with non-unique indices may
+ instead override `_resource_apply_sparse_duplicate_indices` to avoid this
+ overhead.
+
+ Args:
+ grad: a `Tensor` representing the gradient for the affected indices.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ indices: a `Tensor` of integral type representing the indices for
+ which the gradient is nonzero. Indices are unique.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
+ def _apply_sparse_duplicate_indices(self, grad, var, state):
+ """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
+
+ Optimizers which override this method must deal with IndexedSlices objects
+ such as the following:
+
+ IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
+
+ The correct interpretation is:
+
+ IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
+
+ Many optimizers deal incorrectly with repeated indices when updating based
+ on sparse gradients (e.g. summing squares rather than squaring the sum, or
+ applying momentum terms multiple times). Adding first is always the correct
+ behavior, so this is enforced here by reconstructing the IndexedSlices to
+ have only unique indices, then calling _apply_sparse.
+
+ Optimizers which deal correctly with repeated indices may instead override
+ this method to avoid the overhead of summing indices.
+
+ Args:
+ grad: `IndexedSlices`.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ # pylint: disable=protected-access
+ summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
+ values=grad.values, indices=grad.indices)
+ # pylint: enable=protected-access
+ gradient_no_duplicate_indices = ops.IndexedSlices(
+ indices=unique_indices,
+ values=summed_values,
+ dense_shape=grad.dense_shape)
+ return self._apply_sparse(gradient_no_duplicate_indices, var, state)
+
+ def _apply_sparse(self, grad, var, state):
+ """Add ops to apply sparse gradients to `var`.
+
+ The IndexedSlices object passed to `grad` in this function is by default
+ pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
+ indices (see its docstring for details). Optimizers which can tolerate or
+ have correct special cases for duplicate sparse indices may override
+ `_apply_sparse_duplicate_indices` instead of this function, avoiding that
+ overhead.
+
+ Args:
+ grad: `IndexedSlices`, with no repeated indices.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ raise NotImplementedError()
+
+ def _finish(self, state):
+ """Do what is needed to finish the update.
+
+ This is called inside a scope colocated with any non-slot variables.
+
+ Args:
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ The operation to apply updates, or None if no updates.
+ """
+ return None
+
+ # --------------
+ # Utility methods for subclasses.
+ # --------------
+ def _get_per_graph_state(self):
+ # pylint: disable=protected-access
+ return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)
+
+ def _get_state_for_var(self, var):
+ # pylint: disable=protected-access
+ return self._per_graph_state.get(var._graph_key, None)
+
+ # --------------
+ # Overridden methods from Checkpointable.
+ # --------------
+
+ def _track_checkpointable(self, *args, **kwargs):
+ """Optimizers may not track dependencies. Raises an error."""
+ raise NotImplementedError(
+ "Optimizers may not have dependencies. File a feature request if this "
+ "limitation bothers you.")
+
+ @property
+ def _checkpoint_dependencies(self):
+ """From Checkpointable. Gather graph-specific non-slot variables to save."""
+ current_graph_non_slot_variables = []
+ state = self._get_per_graph_state()
+ if state is not None:
+ for name, variable_object in sorted(
+ state._non_slot_dict.items(), # pylint: disable=protected-access
+ # Avoid comparing variables
+ key=lambda item: item[0]):
+ current_graph_non_slot_variables.append(
+ checkpointable.CheckpointableReference(
+ name=name, ref=variable_object))
+ # Note: ignores super(); Optimizers may not have any dependencies outside of
+ # state objects.
+ return current_graph_non_slot_variables
+
+ def _lookup_dependency(self, name):
+ """From Checkpointable. Find a non-slot variable in the current graph."""
+ state = self._get_per_graph_state()
+ if state is None:
+ return None
+ else:
+ return state.get_non_slot(name)
+
+ @property
+ def _deferred_dependencies(self):
+ """Lets Checkpointable know where non-slot variables are created.
+
+ If necessary, creates a new state object for the current default graph.
+ Checkpointable will then add entries to that state's deferred dependency
+ dictionary. The state object will check that dictionary when creating
+ non-slot variables, restoring their value if an entry is found.
+
+ Returns:
+ A dictionary which holds deferred dependencies for the current default
+ graph.
+ """
+ state = self._get_or_create_state()
+ return state._deferred_dependencies # pylint: disable=protected-access
+
+ def _create_or_restore_slot_variable(
+ self, slot_variable_position, slot_name, variable):
+ """Checkpointable: Restore a slot variable's value, possibly creating it.
+
+ Called when a variable which has an associated slot variable is created or
+ restored.
+
+ Args:
+ slot_variable_position: A `checkpointable._CheckpointPosition` object
+ indicating the slot variable `Checkpointable` object to be restored.
+ slot_name: The name of this `Optimizer`'s slot to restore into.
+ variable: The variable object this slot is being created for.
+ """
+ state = self._get_or_create_state(var_list=[variable])
+ state._create_or_restore_slot_variable( # pylint: disable=protected-access
+ slot_variable_position=slot_variable_position,
+ slot_name=slot_name,
+ variable=variable,
+ optional_op_name=self._name)
+
+ # --------------
+ # Unsupported parent methods
+ # --------------
+ def _slot_dict(self, slot_name):
+ raise NotImplementedError(
+ "_slot_dict() method unsupported in OptimizerV2")
+
+ def _get_or_make_slot(self, var, val, slot_name, op_name):
+ raise NotImplementedError(
+ "_get_or_make_slot() method unsupported in OptimizerV2")
+
+ def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
+ slot_name, op_name):
+ raise NotImplementedError(
+ "_get_or_make_slot_with_initializer() method unsupported in "
+ "OptimizerV2")
+
+ def _create_non_slot_variable(self, initial_value, name, colocate_with):
+ raise NotImplementedError(
+ "_create_non_slot_variable() method unsupported in OptimizerV2")
+
+ def _get_non_slot_variable(self, name, graph=None):
+ raise NotImplementedError(
+ "_get_non_slot_variable() method unsupported in OptimizerV2")
+
+ def _non_slot_variables(self):
+ raise NotImplementedError(
+ "_non_slot_variables() method unsupported in OptimizerV2")
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
new file mode 100644
index 0000000000..a6c939393e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -0,0 +1,277 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional test for OptimizerV2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class OptimizerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBasic(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ # Note that for eager execution, minimize expects a function instead of a
+ # Tensor.
+ global_step = resource_variable_ops.ResourceVariable(
+ array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
+ sgd_op = sgd.SGD(3.0)
+
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+ # Run 1 step of sgd through optimizer
+ opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
+ self.evaluate(opt_op)
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+ def testAggregationMethod(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(
+ cost,
+ global_step, [var0, var1],
+ aggregation_method=gradients_impl.AggregationMethod.
+ EXPERIMENTAL_ACCUMULATE_N)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-14., -13.], var0.eval())
+ self.assertAllClose([-6., -5.], var1.eval())
+
+ def testPrecomputedGradient(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ cost = 5 * var0 + 3 * var1
+ grad_loss = constant_op.constant([42, -42], dtype=dtype)
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(
+ cost, global_step, [var0, var1], grad_loss=grad_loss)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
+ var0.eval())
+ self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
+ var1.eval())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoVariables(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, trainable=False, name='a')
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, trainable=False, name='b')
+ return 5 * var0 + var1
+ # pylint: enable=cell-var-from-loop
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError, 'No.*variables'):
+ sgd_op.minimize(loss)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradients(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ return 5 * var0
+ # pylint: enable=cell-var-from-loop
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError, 'No gradients'):
+ # var1 has no gradient
+ sgd_op.minimize(loss, var_list=[var1])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradientsForAnyVariables_Minimize(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return constant_op.constant(5.0)
+
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.minimize(loss, var_list=[var0, var1])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradientsForAnyVariables_ApplyGradients(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.apply_gradients([(None, var0), (None, var1)])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testGradientsAsVariables(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
+ # Convert gradients to tf.Variables
+ converted_grads = [
+ resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
+ name='c_%d_%d' % (i, j))
+ for j, gv in enumerate(grads_and_vars)
+ ]
+ convert_ops = [
+ state_ops.assign(converted_grads[j], gv[0])
+ for j, gv in enumerate(grads_and_vars)
+ ]
+
+ self.evaluate(variables.global_variables_initializer())
+ # Run convert_ops to achieve the gradietns converting
+ self.evaluate(convert_ops)
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Run 1 step of sgd through optimizer
+ converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+ opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
+ self.evaluate(opt_op)
+
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testComputeGradientsWithTensors(self):
+ x = ops.convert_to_tensor(1.0)
+ def f():
+ return x * x
+
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(f, [x])
+ self.assertEqual(1, len(grads_and_vars))
+ grad, x_as_var = grads_and_vars[0]
+ self.assertIs(x, x_as_var)
+ self.assertEqual(2.0, self.evaluate(grad))
+
+ with self.assertRaises(NotImplementedError):
+ sgd_op.apply_gradients(grads_and_vars)
+
+ def testTrainOp(self):
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0])
+ var1 = variables.Variable([3.0, 4.0])
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+ self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))
+
+ def testConstraint(self):
+ constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
+ constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0],
+ constraint=constraint_01)
+ var1 = variables.Variable([3.0, 4.0],
+ constraint=constraint_0)
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-0.1, -0.1], var0.eval())
+ self.assertAllClose([0., 0.], var1.eval())
+
+ def testStopGradients(self):
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], name='var0')
+ var1 = variables.Variable([3.0, 4.0], name='var1')
+ var0_id = array_ops.identity(var0)
+ cost = 5 * var0_id + 3 * var1
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1],
+ stop_gradients=[var0_id])
+ grad_dict = {var.op.name: grad for grad, var in grads_and_vars}
+ self.assertIsNone(grad_dict['var0'])
+ self.assertIsNotNone(grad_dict['var1'])
+
+ def testDoNotOverrideCreateSlots(self):
+ class ShouldNotOverrideCreateSlots(optimizer_v2.OptimizerV2):
+
+ def _create_slots(self, var_list):
+ """In OptimizerV2 _create_slots was renamed _create_vars."""
+ return var_list
+
+ with self.assertRaises(RuntimeError):
+ ShouldNotOverrideCreateSlots('name')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
new file mode 100644
index 0000000000..2748d8eff7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -0,0 +1,239 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RMSprop optimizer for Tensorflow.
+
+rmsprop algorithm [tieleman2012rmsprop]
+
+A detailed description of rmsprop.
+
+- maintain a moving (discounted) average of the square of gradients
+- divide gradient by the root of this average
+
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
+delta = - mom
+
+This implementation of RMSProp uses plain momentum, not Nesterov momentum.
+
+The centered version additionally maintains a moving (discounted) average of the
+gradients, and uses that average to estimate the variance:
+
+mean_grad = rho * mean_square{t-1} + (1-rho) * gradient
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t /
+ sqrt(mean_square - mean_grad**2)
+delta = - mom
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+
+from tensorflow.python.training import training_ops
+
+
+class RMSProp(optimizer_v2.OptimizerV2):
+ """RMSProp optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values (except the learning rate, which can be freely tuned).
+
+ This optimizer is usually a good choice for recurrent neural networks.
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Note that in the dense implementation of this algorithm, variables and their
+ corresponding accumulators (momentum, gradient moving average, square
+ gradient moving average) will be updated even if the gradient is zero
+ (i.e. accumulators will decay, momentum will be applied). The sparse
+ implementation (used when the gradient is an `IndexedSlices` object,
+ typically because of `tf.gather` or an embedding lookup in the forward pass)
+ will not update variable slices or their accumulators unless those slices
+ were used in the forward pass (nor is there an "eventual" correction to
+ account for these omitted updates). This leads to more efficient updates for
+ large embedding lookup tables (where most of the slices are not accessed in
+ a particular graph execution), but differs from the published algorithm.
+
+ Arguments:
+ learning_rate: A float hyperparameter >= 0. The learning rate.
+ rho: A float hyperparameter >= 0. Discounting factor for the
+ history/coming gradient.
+ momentum: A float hyperparameter >= 0.
+ epsilon: A float hyperparameter >= 0 . Small value to initialize the
+ average square gradient variable and avoid zero denominator.
+ centered: If True, gradients are normalized by the estimated variance of
+ the gradient; if False, by the uncentered second moment. Setting this to
+ True may help with training, but is slightly more expensive in terms of
+ computation and memory. Defaults to False.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "RMSProp".
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ rho=0.9,
+ momentum=None,
+ epsilon=1e-10,
+ centered=False,
+ name="RMSProp"):
+ super(RMSProp, self).__init__(name)
+ # Momentum default is `None` for consistency with SGD
+ # but underlying implementation uses `momentum` hyperparameter here
+ # regardless unlike SGD. Since extneral Keras RMSProp does not have
+ # a `momentum` weight, for compatibility with external Keras h5 files,
+ # when `momentum` was set as `None` we should ignore the `momentum`
+ # variable in `get_weights` and not require it in `set_weights`.
+ if momentum is None:
+ momentum = 0.0
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("rho", rho)
+ self._set_hyper("momentum", momentum)
+ self._set_hyper("epsilon", epsilon)
+
+ self._centered = centered
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ init_rms = state.get_hyper(
+ "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
+ state.create_slot_with_initializer(v, init_rms, v.get_shape(),
+ v.dtype.base_dtype, "rms")
+ if self._centered:
+ state.zeros_slot(v, "mg")
+ state.zeros_slot(v, "momentum")
+
+ def _apply_dense(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.apply_centered_rms_prop(
+ var,
+ mg,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ # epsilon is now the rms initial value and is not added to the
+ # denominator anymore, hence calling the kernel op with epsilon=0.
+ 0,
+ grad,
+ use_locking=self._use_locking).op
+ else:
+ return training_ops.apply_rms_prop(
+ var,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.resource_apply_centered_rms_prop(
+ var.handle,
+ mg.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_apply_rms_prop(
+ var.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.sparse_apply_centered_rms_prop(
+ var,
+ mg,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.sparse_apply_rms_prop(
+ var,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = self.get_slot(var, "mg")
+ return training_ops.resource_sparse_apply_centered_rms_prop(
+ var.handle,
+ mg.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ indices,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_sparse_apply_rms_prop(
+ var.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
new file mode 100644
index 0000000000..2c5eccdc5b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -0,0 +1,444 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rmsprop optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import math
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+_DATA_TYPES = [dtypes.half, dtypes.float32]
+
+_TEST_PARAM_VALUES = [
+ # learning_rate, rho, momentum, epsilon, centered, use_resource
+ [0.5, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.9, 0.0, 1.0, False, False],
+ [0.5, 0.9, 0.0, 1.0, True, True],
+ [0.5, 0.9, 0.0, 1.0, False, True],
+ [0.1, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.95, 0.0, 1.0, False, False],
+ [0.5, 0.8, 0.0, 1e-3, True, False],
+ [0.5, 0.8, 0.9, 1e-3, True, False],
+]
+
+
+class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
+
+ def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, rho, momentum,
+ centered):
+ rms_t = rms * rho + (1 - rho) * g * g
+ if centered:
+ mg_t = mg * rho + (1 - rho) * g
+ denom_t = rms_t - mg_t * mg_t
+ else:
+ mg_t = mg
+ denom_t = rms_t
+ mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+ var_t = var - mom_t
+ return var_t, mg_t, rms_t, mom_t
+
+ def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
+ lr, rho, momentum, centered):
+ mg_t = copy.deepcopy(mg)
+ rms_t = copy.deepcopy(rms)
+ mom_t = copy.deepcopy(mom)
+ var_t = copy.deepcopy(var)
+ for i in range(len(gindexs)):
+ gindex = gindexs[i]
+ gvalue = gvalues[i]
+ rms_t[gindex] = rms[gindex] * rho + (1 - rho) * gvalue * gvalue
+ denom_t = rms_t[gindex]
+ if centered:
+ mg_t[gindex] = mg_t[gindex] * rho + (1 - rho) * gvalue
+ denom_t -= mg_t[gindex] * mg_t[gindex]
+ mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
+ var_t[gindex] = var[gindex] - mom_t[gindex]
+ return var_t, mg_t, rms_t, mom_t
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testDense(self, dtype, param_value):
+ (learning_rate, rho, momentum, epsilon, centered,
+ use_resource) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, rho,
+ momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, rho,
+ momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=0.01, half_atol=0.01)
+ self.assertAllCloseAccordingToType(
+ var1_np, var1.eval(), half_rtol=0.01, half_atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariable(self, dtype):
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.0, momentum=0.0, epsilon=0.0,
+ centered=False).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0., 1.]], var0.eval(), atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariableCentered(self, dtype):
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.1, momentum=0.0, epsilon=1.0,
+ centered=True).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-7/3.0, -4/3.0]], var0.eval(), atol=0.01)
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testSparse(self, dtype, param_value):
+ (learning_rate, rho, momentum, epsilon, centered, _) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0_np_indices = np.array([0], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([1]))
+ grads1_np_indices = np.array([1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([1]))
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
+ var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
+ learning_rate, rho, momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
+ var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
+ learning_rate, rho, momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithoutMomentum(self, dtype):
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the rms accumulators where 1. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithMomentum(self, dtype):
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.5, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: rms = 1, mom = 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the momentum accumulators
+ self.assertAllCloseAccordingToType(
+ np.array([(0.1 * 2.0 / math.sqrt(0.901)),
+ (0.1 * 2.0 / math.sqrt(0.901))]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.01 * 2.0 / math.sqrt(0.90001)),
+ (0.01 * 2.0 / math.sqrt(0.90001))]), mom1.eval())
+
+ # Check that the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), mom1.eval())
+
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)))
+ ]), var0.eval())
+
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)))
+ ]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/sgd.py b/tensorflow/python/keras/optimizer_v2/sgd.py
new file mode 100644
index 0000000000..f5583691f7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd.py
@@ -0,0 +1,170 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Momentum for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import training_ops
+
+
+class SGD(optimizer_v2.OptimizerV2):
+ """Stochastic gradient descent optimizer.
+
+ Includes support for momentum and Nesterov momentum.
+
+ Computes (if `nesterov = False`):
+
+ ```
+ accumulation = momentum * accumulation + gradient
+ variable -= learning_rate * accumulation
+ ```
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Note that in the dense version of this algorithm, `accumulation` is updated
+ and applied regardless of a gradient's value, whereas the sparse version (when
+ the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
+ embedding) only updates variable slices and corresponding `accumulation` terms
+ when that part of the variable was used in the forward pass.
+
+ @compatibility(eager)
+ When eager execution is enabled, learning_rate and momentum can each be a
+ callable that takes no arguments and returns the actual value to use. This
+ can be useful for changing these values across different invocations of
+ optimizer functions.
+ @end_compatibility
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ momentum: float hyperparameter >= 0 or None. Parameter that accelerates
+ SGD in the relevant direction and dampens oscillations.
+ nesterov: boolean. Whether to apply Nesterov momentum. See [Sutskever et
+ al., 2013](http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). This
+ implementation always computes gradients at the value of the
+ variable(s) passed to the optimizer. Using Nesterov Momentum makes the
+ variable(s) track the values called `theta_t + mu*v_t` in the paper.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'SGD'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ momentum=None,
+ nesterov=False,
+ name="SGD"):
+ super(SGD, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ # Only create momentum variables and use momentum ops if needed.
+ if momentum is not None:
+ self._set_hyper("momentum", momentum)
+ self._use_nesterov = nesterov
+ self._use_momentum = True
+ else:
+ self._use_momentum = False
+
+ def _create_vars(self, var_list, state):
+ if self._use_momentum:
+ for v in var_list:
+ state.zeros_slot(v, "momentum")
+
+ def _apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return training_ops.apply_gradient_descent(
+ var,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return training_ops.resource_apply_gradient_descent(
+ var.handle, lr, grad, use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return super(SGD, self)._apply_sparse(grad, var, state)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_sparse_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ return super(SGD, self)._resource_apply_sparse(grad, var, indices, state)
+
+ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices, state):
+ if self._use_momentum:
+ return super(SGD, self)._resource_apply_sparse_duplicate_indices(
+ grad, var, indices, state)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return resource_variable_ops.resource_scatter_add(var.handle, indices,
+ -grad * lr)
+
+ def _apply_sparse_duplicate_indices(self, grad, var, state):
+ if self._use_momentum:
+ return super(SGD, self)._apply_sparse_duplicate_indices(grad, var, state)
+ else:
+ delta = ops.IndexedSlices(
+ grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.indices, grad.dense_shape)
+ return var.scatter_sub(delta, use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/sgd_test.py b/tensorflow/python/keras/optimizer_v2/sgd_test.py
new file mode 100644
index 0000000000..eb39aac283
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd_test.py
@@ -0,0 +1,759 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Momentum."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class GradientDescentOptimizerTest(test.TestCase):
+
+ def testBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ optimizer = sgd.SGD(3.0)
+ sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ self.assertEqual(0, len(optimizer.variables()))
+
+ def testBasicResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testMinimizeResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(var0, x) + var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ pred += var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ lrate = constant_op.constant(3.0)
+ sgd_op = sgd.SGD(lrate).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testGradWrtRef(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ opt = sgd.SGD(3.0)
+ values = [1.0, 3.0]
+ vars_ = [variables.Variable([v], dtype=dtype) for v in values]
+ grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_)
+ variables.global_variables_initializer().run()
+ for grad, _ in grads_and_vars:
+ self.assertAllCloseAccordingToType([1.0], grad.eval())
+
+ def testWithGlobalStep(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ global_step = variables.Variable(0, trainable=False)
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params and global_step
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ self.assertAllCloseAccordingToType(1, global_step.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[1.0 - 3.0 * 0.1], [2.0]],
+ var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
+ var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
+
+
+class MomentumOptimizerTest(test.TestCase):
+
+ def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
+ var = var + accum * lr * momentum
+ accum = accum * momentum + g
+ var = var - lr * accum
+ var = var - accum * lr * momentum
+ return var, accum
+
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ learning_rate = lambda: 2.0
+ momentum = lambda: 0.9
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ momentum = momentum()
+ mom_opt = sgd.SGD(learning_rate=learning_rate, momentum=momentum)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ if not context.executing_eagerly():
+ self.assertFalse(slot0 in variables.trainable_variables())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ if not context.executing_eagerly():
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+ self.evaluate(var1))
+ # Step 2: the momentum accumulators contain the previous update.
+ if context.executing_eagerly():
+ mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ else:
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), self.evaluate(var1))
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testVariablesAcrossGraphs(self):
+ optimizer = sgd.SGD(0.01, 0.5)
+ with ops.Graph().as_default():
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var0")
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var1")
+ loss = math_ops.reduce_sum(var0 + var1)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var0")
+ self.assertStartsWith(optimizer_variables[1].name, "var1")
+ self.assertEquals(2, len(optimizer_variables))
+
+ with ops.Graph().as_default():
+ var2 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var2")
+ var3 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var3")
+ loss = math_ops.reduce_sum(var2 + var3)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var2")
+ self.assertStartsWith(optimizer_variables[1].name, "var3")
+ self.assertEquals(2, len(optimizer_variables))
+
+ def testNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ cost = 5 * var0 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name="global_step")
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ opt_op = mom_op.minimize(cost, global_step, [var0, var1])
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_op.run()
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ def testSparseNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ grads = []
+ for t in range(1, 5):
+ grads.append(var0_np * 10)
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ loss = 5 * var0 * var0 + 3 * var1
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ x_feed = array_ops.placeholder(dtype)
+ y_feed = ops.IndexedSlices(
+ x_feed, constant_op.constant([0, 1]), constant_op.constant([2]))
+ grads_and_vars = [(y_feed, var0), (constant_op.constant(
+ [3.0, 3.0], dtype=dtype), var1)]
+ opt_update = mom_op.apply_gradients(grads_and_vars)
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_update.run(feed_dict={x_feed: grads[t - 1]})
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ return pred * pred
+ # pylint: enable=cell-var-from-loop
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ # Run 1 step of sgd
+ self.evaluate(sgd_op)
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+
+ def loss():
+ return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(sgd_op)
+ self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))
+
+ def testTensorLearningRateAndMomentum(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(
+ learning_rate=constant_op.constant(2.0),
+ momentum=constant_op.constant(0.9))
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ self.assertFalse(slot0 in variables.trainable_variables())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+ def _dbParamsMom01(self):
+ """Return dist-belief momentum values.
+
+ Return values been generated from the dist-belief momentum unittest,
+ running with a learning rate of 0.1 and a momentum of 0.1.
+
+ These values record how a parameter vector of size 10, initialized with 0.0,
+ gets updated with 10 consecutive momentum steps. It uses random gradients.
+
+ Returns:
+ db_grad: The gradients to apply
+ db_out: The parameters after the momentum update.
+ """
+ db_grad = [[]] * 10
+ db_out = [[]] * 10
+ # pylint: disable=line-too-long
+ db_grad[0] = [
+ 0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018,
+ 0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615
+ ]
+ db_out[0] = [
+ -9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018,
+ -0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618
+ ]
+ db_grad[1] = [
+ 0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378,
+ 0.5513742, 0.94687688, 0.16012503, 0.22159521
+ ]
+ db_out[1] = [
+ -0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884,
+ -0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544
+ ]
+ db_grad[2] = [
+ 0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965,
+ 0.31168157, 0.43203235, 0.16792089, 0.24644311
+ ]
+ db_out[2] = [
+ -0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978,
+ -0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189
+ ]
+ db_grad[3] = [
+ 0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098,
+ 0.81454384, 0.03848977, 0.89759839, 0.93665648
+ ]
+ db_out[3] = [
+ -0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105,
+ -0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303
+ ]
+ db_grad[4] = [
+ 0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359,
+ 0.69107032, 0.81897682, 0.5433259, 0.67860287
+ ]
+ db_out[4] = [
+ -0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165,
+ -0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544
+ ]
+ db_grad[5] = [
+ 0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563,
+ 0.84163809, 0.41172323, 0.83259648, 0.44941229
+ ]
+ db_out[5] = [
+ -0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094,
+ -0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717
+ ]
+ db_grad[6] = [
+ 0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221,
+ 0.73577434, 0.16014607, 0.57500273, 0.071136251
+ ]
+ db_out[6] = [
+ -0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685,
+ -0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997
+ ]
+ db_grad[7] = [
+ 0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646,
+ 0.74053431, 0.16033, 0.66625422, 0.73515922
+ ]
+ db_out[7] = [
+ -0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838,
+ -0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418
+ ]
+ db_grad[8] = [
+ 0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039,
+ 0.55561525, 0.22567581, 0.93331909, 0.29438227
+ ]
+ db_out[8] = [
+ -0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527,
+ -0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781
+ ]
+ db_grad[9] = [
+ 0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893,
+ 0.68593478, 0.50580865, 0.12602448, 0.093537711
+ ]
+ db_out[9] = [
+ -0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302,
+ -0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295
+ ]
+ # pylint: enable=line-too-long
+ return db_grad, db_out
+
+ def testLikeDistBeliefMom01(self):
+ with self.cached_session():
+ db_grad, db_out = self._dbParamsMom01()
+ num_samples = len(db_grad)
+ var0 = variables.Variable([0.0] * num_samples)
+ grads0 = constant_op.constant([0.0] * num_samples)
+ mom_opt = sgd.SGD(learning_rate=0.1, momentum=0.1)
+ mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
+ variables.global_variables_initializer().run()
+ for i in xrange(num_samples):
+ mom_update.run(feed_dict={grads0: db_grad[i]})
+ self.assertAllClose(np.array(db_out[i]), var0.eval())
+
+ def testSparse(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
+ var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.1, .1]], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([4, 2]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.01, .01], [.01, .01]], dtype=dtype),
+ constant_op.constant([2, 3]),
+ constant_op.constant([4, 2]))
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([0, 0], var0.eval()[0])
+ self.assertAllClose([0, 0], var0.eval()[1])
+ self.assertAllClose([1, 1], var1.eval()[2])
+
+ # Step 1: the momentum accumulators are 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(np.array([.1, .1]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([.01, .01]), slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([-(0.1 * 2.0), -(0.1 * 2.0)]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.01 * 2.0), 1.0 - (0.01 * 2.0)]), var1.eval()[2])
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ -(0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), -(0.1 * 2.0) - (
+ (0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.98 - ((0.9 * 0.01 + 0.01) * 2.0), 0.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval()[2])
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update1 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ mom_update2 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update1.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the second momentum accumulators contain the previous update.
+ mom_update2.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 501b50ba5f..2fae094a1e 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -166,8 +166,9 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
if expected_dim is not None:
if expected_dim != actual_dim:
raise AssertionError(
- 'When testing layer %s, for input %s, found output_shape='
- '%s but expected to find %s.\nFull kwargs: %s' %
+ 'When testing layer %s **after deserialization**, '
+ 'for input %s, found output_shape='
+ '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
(layer_cls.__name__,
x,
actual_output_shape,
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 8ebca1418d..f486e631e5 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -137,26 +137,49 @@ def conv_input_length(output_length, filter_size, padding, stride):
return (output_length - 1) * stride - 2 * pad + filter_size
-def deconv_output_length(input_length, filter_size, padding, stride):
+def deconv_output_length(input_length, filter_size, padding,
+ output_padding=None, stride=0, dilation=1):
"""Determines output length of a transposed convolution given input length.
Arguments:
- input_length: integer.
- filter_size: integer.
- padding: one of "same", "valid", "full".
- stride: integer.
+ input_length: Integer.
+ filter_size: Integer.
+ padding: one of `"same"`, `"valid"`, `"full"`.
+ output_padding: Integer, amount of padding along the output dimension.
+ Can be set to `None` in which case the output length is inferred.
+ stride: Integer.
+ dilation: Integer.
Returns:
The output length (integer).
"""
+ assert padding in {'same', 'valid', 'full'}
if input_length is None:
return None
- input_length *= stride
- if padding == 'valid':
- input_length += max(filter_size - stride, 0)
- elif padding == 'full':
- input_length -= (stride + filter_size - 2)
- return input_length
+
+ # Get the dilated kernel size
+ filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+
+ # Infer length if output padding is None, else compute the exact length
+ if output_padding is None:
+ if padding == 'valid':
+ length = input_length * stride + max(filter_size - stride, 0)
+ elif padding == 'full':
+ length = input_length * stride - (stride + filter_size - 2)
+ elif padding == 'same':
+ length = input_length * stride
+
+ else:
+ if padding == 'same':
+ pad = filter_size // 2
+ elif padding == 'valid':
+ pad = 0
+ elif padding == 'full':
+ pad = filter_size - 1
+
+ length = ((input_length - 1) * stride + filter_size - 2 * pad +
+ output_padding)
+ return length
def normalize_data_format(value):
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
index e1c49bc852..04b2ea8fe3 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -244,9 +244,24 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
+ # Deduplicate output names to handle Siamese networks.
+ occurrences = {}
+ for n in model.output_names:
+ if n not in occurrences:
+ occurrences[n] = 1
+ else:
+ occurrences[n] += 1
+ conflict_counter = {n: 0 for n, count in occurrences.items() if count > 1}
+ output_names = []
+ for n in model.output_names:
+ if n in conflict_counter:
+ conflict_counter[n] += 1
+ n += '_%d' % conflict_counter[n]
+ output_names.append(n)
+
# Merge outputs under expected scope.
with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]):
merged = []
- for name, outputs in zip(model.output_names, all_outputs):
+ for name, outputs in zip(output_names, all_outputs):
merged.append(concatenate(outputs, axis=0, name=name))
return Model(model.inputs, merged)
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 3d0351a11f..1780ab6587 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -198,5 +198,31 @@ class TestMultiGPUModel(test.TestCase):
parallel_model.compile(loss='mean_squared_error', optimizer='adam')
parallel_model.train_on_batch(x, y)
+ def test_multi_gpu_with_siamese_network(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.cached_session():
+ input_shape = (3,)
+ nested_model = keras.models.Sequential([
+ keras.layers.Dense(32, input_shape=input_shape),
+ keras.layers.Dense(1)
+ ], name='nested')
+
+ input1 = keras.Input(input_shape)
+ input2 = keras.Input(input_shape)
+ score1 = nested_model(input1)
+ score2 = nested_model(input2)
+ score_sum = keras.layers.Add(name='add')([score1, score2])
+
+ siamese = keras.models.Model(inputs=[input1, input2],
+ outputs=[score_sum, score1, score2],
+ name='siamese')
+ parallel_siamese = keras.utils.multi_gpu_model(siamese, gpus)
+ self.assertEqual(parallel_siamese.output_names,
+ ['add', 'nested_1', 'nested_2'])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py
index c24e87308b..3763999bff 100644
--- a/tensorflow/python/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/utils/np_utils.py
@@ -22,7 +22,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.utils.to_categorical')
-def to_categorical(y, num_classes=None):
+def to_categorical(y, num_classes=None, dtype='float32'):
"""Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
@@ -31,6 +31,7 @@ def to_categorical(y, num_classes=None):
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes.
+ dtype: The data type expected by the input. Default: `'float32'`.
Returns:
A binary matrix representation of the input. The classes axis is placed
@@ -44,7 +45,7 @@ def to_categorical(y, num_classes=None):
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
- categorical = np.zeros((n, num_classes), dtype=np.float32)
+ categorical = np.zeros((n, num_classes), dtype=dtype)
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index e055ef1c1b..4e8639dfc8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3255,7 +3255,7 @@ tf_py_test(
tags = ["no_pip"],
)
-tf_py_test(
+cuda_py_test(
name = "cond_v2_test",
size = "medium",
srcs = ["cond_v2_test.py"],
@@ -3272,7 +3272,6 @@ tf_py_test(
"//tensorflow/python:training",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/111656070)
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
index 78b6e38d94..5777a5d097 100644
--- a/tensorflow/python/kernel_tests/benchmark_test.py
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -64,7 +64,7 @@ class TestReportingBenchmark(test.Benchmark):
"other_key": "string"})
def benchmark_times_an_op(self):
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
a = constant_op.constant(0.0)
a_plus_a = a + a
return self.run_op_benchmark(
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 8a58b3f97e..8177cdd454 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@@ -97,6 +99,22 @@ class BincountTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError):
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
+ def test_shape_function(self):
+ # size must be scalar.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
+ # size must be positive.
+ with self.assertRaisesRegexp(ValueError, "must be non-negative"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
+ # if size is a constant then the shape is known.
+ v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
+ self.assertAllEqual(v1.get_shape().as_list(), [5])
+ # if size is a placeholder then the shape is unknown.
+ s = array_ops.placeholder(dtype=dtypes.int32)
+ v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
+ self.assertAllEqual(v2.get_shape().as_list(), [None])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index 782e6b5068..2ebf74a4d7 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -327,7 +328,7 @@ class CholeskyBenchmark(test.Benchmark):
def benchmarkCholeskyOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
l = linalg_ops.cholesky(matrix)
@@ -341,7 +342,7 @@ class CholeskyBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/device:GPU:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
l = linalg_ops.cholesky(matrix)
@@ -359,7 +360,7 @@ class CholeskyBenchmark(test.Benchmark):
for shape in self.shapes:
matrix = self._GenerateMatrix(shape)
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device(device):
l = variables.Variable(np.linalg.cholesky(matrix))
grad_matrix = variables.Variable(
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 377c041675..a424a0f219 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -153,6 +153,7 @@ class CondV2Test(test.TestCase):
self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions)
def testDefunInCond(self):
+ self.skipTest("b/117293122")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -172,7 +173,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -198,7 +199,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -468,7 +469,6 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
- self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -502,29 +502,29 @@ class CondV2Test(test.TestCase):
return grads, pred_outer, pred_inner
- with ops.Graph().as_default():
+ with ops.Graph().as_default(), self.session(
+ graph=ops.get_default_graph()) as sess:
grads, pred_outer, pred_inner = build_graph()
- with self.session(graph=ops.get_default_graph()) as sess:
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: True
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: False
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: True
- }), [4., 2.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: False
- }), [5., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
def testSecondDerivative(self):
with self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index c7e89dd5f9..baea5c0f6d 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -661,10 +660,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
- @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
- graph = ops.Graph()
- with graph.as_default():
+ with self.cached_session():
x = constant_op.constant(10.0, name="x")
pred = math_ops.less(1, 2)
fn1 = lambda: array_ops.identity(x)
@@ -672,8 +669,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
grad = gradients_impl.gradients(r, [x])[0]
- with self.cached_session():
- self.assertAllEqual(1.0, grad.eval())
+ self.assertAllEqual(1.0, grad.eval())
def testCondGrad_2(self):
with self.cached_session():
@@ -3424,9 +3420,6 @@ class EagerTest(test.TestCase):
# TODO(b/117279927): Re-enable once msan failure is fixed.
def DISABLED_testCondInDefun(self):
- if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
- return unittest.skip("b/113346829 (gpu failure)")
-
with context.eager_mode():
@eager_function.defun
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index a52b2c0dc3..fb114f9f24 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -185,8 +186,8 @@ class MatrixDeterminantBenchmark(test.Benchmark):
def benchmarkMatrixDeterminantOp(self):
for shape in self.shapes:
- with ops.Graph().as_default(), session.Session() as sess, ops.device(
- "/cpu:0"):
+ with ops.Graph().as_default(), session.Session(
+ config=benchmark.benchmark_config()) as sess, ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
d = linalg_ops.matrix_determinant(matrix)
variables.global_variables_initializer().run()
@@ -198,8 +199,8 @@ class MatrixDeterminantBenchmark(test.Benchmark):
name="matrix_determinant_cpu_{shape}".format(shape=shape))
if test.is_gpu_available(True):
- with ops.Graph().as_default(), session.Session() as sess, ops.device(
- "/gpu:0"):
+ with ops.Graph().as_default(), session.Session(
+ config=benchmark.benchmark_config()) as sess, ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
d = linalg_ops.matrix_determinant(matrix)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 4beddd00bb..2f19ecc0e6 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -306,6 +306,19 @@ class PrintV2Test(test.TestCase):
logging_ops.print_v2(tensor)
self.assertTrue((expected + "\n") in printed.contents())
+ def testPrintsOrderedInDefun(self):
+ with context.eager_mode():
+
+ @function.defun
+ def prints():
+ logging_ops.print_v2("A")
+ logging_ops.print_v2("B")
+ logging_ops.print_v2("C")
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ prints()
+ self.assertTrue(("A\nB\nC\n") in printed.contents())
+
@test_util.run_in_graph_and_eager_modes()
def testPrintInDefunWithoutExplicitEvalOfPrint(self):
@function.defun
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index 68d626de2c..a0ef3a607e 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test as test_lib
@@ -109,7 +110,7 @@ class MatrixBandPartBenchmark(test_lib.Benchmark):
for shape_ in self.shapes:
for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = variables.Variable(array_ops.ones(shape_))
band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
@@ -123,7 +124,7 @@ class MatrixBandPartBenchmark(test_lib.Benchmark):
if test_lib.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = variables.Variable(array_ops.ones(shape_))
band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
index 0386e91276..9630c052b8 100644
--- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -181,7 +182,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
def benchmarkMatrixExponentialOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
expm = linalg_impl.matrix_exponential(matrix)
@@ -195,7 +196,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
expm = linalg_impl.matrix_exponential(matrix)
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index 720ba806e9..8bda04b53d 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -179,7 +180,7 @@ class MatrixInverseBenchmark(test.Benchmark):
for adjoint in False, True:
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint)
@@ -193,7 +194,7 @@ class MatrixInverseBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint)
diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
index 723a15fbd1..3205e211d9 100644
--- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -159,7 +160,7 @@ class MatrixLogarithmBenchmark(test.Benchmark):
def benchmarkMatrixLogarithmOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
logm = gen_linalg_ops.matrix_logarithm(matrix)
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
index de495968a7..225a10e117 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test as test_lib
@@ -313,7 +314,7 @@ class MatrixSolveLsBenchmark(test_lib.Benchmark):
for num_rhs in 1, 2, matrix_shape[-1]:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
@@ -328,7 +329,7 @@ class MatrixSolveLsBenchmark(test_lib.Benchmark):
if run_gpu_test and (len(matrix_shape) < 3 or matrix_shape[0] < 513):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
index b8f2736b7b..264df2565c 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -167,7 +168,7 @@ class MatrixSolveBenchmark(test.Benchmark):
for num_rhs in 1, 2, matrix_shape[-1]:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix, rhs = self._GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
@@ -185,7 +186,7 @@ class MatrixSolveBenchmark(test.Benchmark):
if run_gpu_test:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix, rhs = self._GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index a45a325b47..672d6556f5 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -282,6 +283,125 @@ class Relu6Test(test.TestCase):
self.assertLess(err, 1e-10)
+class LeakyReluTest(test.TestCase):
+
+ def _npLeakyRelu(self, np_features, alpha=0.1):
+ return np.maximum(np_features, alpha * np_features)
+
+ def testNpLeakyRelu(self):
+ self.assertAllClose(
+ np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
+ [0.1, -0.03, 0.5, -0.07, 0.9]]),
+ self._npLeakyRelu(
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]]),
+ alpha=0.1))
+
+ def _testLeakyRelu(self, np_features, alpha, use_gpu=False):
+ np_leaky_relu = self._npLeakyRelu(np_features, alpha)
+ with self.test_session(use_gpu=use_gpu):
+ leaky_relu = nn_ops.leaky_relu(np_features, alpha)
+ tf_leaky_relu = leaky_relu.eval()
+ self.assertAllClose(np_leaky_relu, tf_leaky_relu)
+ self.assertShapeEqual(np_leaky_relu, leaky_relu)
+
+ def testNumbers(self):
+ for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.2,
+ use_gpu=False)
+ if t in [np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.1,
+ use_gpu=True)
+
+ # The gradient test for Leaky ReLU is a bit tricky as the derivative is not
+ # well defined at around zero and we want to avoid that in terms of input
+ # values.
+ def testGradientFloat32(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradientFloat64(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.2, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradGradFloat32(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradGradFloat64(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradientScalar(self):
+ with self.test_session() as sess:
+ x = variables.Variable(-100.)
+ y = nn_ops.leaky_relu(x, 0.05)
+ loss = y**2
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2)
+ train_op = optimizer.minimize(loss)
+ sess.run(variables.global_variables_initializer())
+ sess.run(train_op)
+ self.assertAllClose(x.eval(), -99.9)
+
+
class EluTest(test.TestCase):
def _npElu(self, np_features):
diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
index 31e84341ae..fdfe1001b8 100644
--- a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
# pylint: disable=protected-access
@@ -192,7 +193,7 @@ class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark):
sorted(zip(indices_batch, indices_value)), dtype=np.int64)
values = ["feature_value_for_embedding_lookup"] * num_elements
shape = np.asarray([batch_size, num_elements], dtype=np.int64)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
with ops.device("/cpu:0"):
indices = variables.Variable(indices)
values = variables.Variable(values)
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 29fb002ef4..04ac589432 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -160,7 +161,7 @@ class WhereBenchmark(test.Benchmark):
x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p
v = resource_variable_ops.ResourceVariable(x)
op = array_ops.where(v)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
v.initializer.run()
r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
gb_processed_input = m * n / 1.0e9
@@ -186,7 +187,7 @@ class WhereBenchmark(test.Benchmark):
y = resource_variable_ops.ResourceVariable(y_gen)
c = resource_variable_ops.ResourceVariable(c_gen)
op = array_ops.where(c, x, y)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
x.initializer.run()
y.initializer.run()
c.initializer.run()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 4be9c532f4..e3e4d5f910 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1407,8 +1407,13 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
gen_array_ops.conjugate_transpose
if (conjugate and a.dtype.is_complex) else gen_array_ops.transpose)
if perm is None:
- rank = gen_array_ops.rank(a)
- perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
+ a = ops.convert_to_tensor(a, name="a")
+ if not a.get_shape().ndims:
+ rank = gen_array_ops.rank(a)
+ perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
+ else:
+ rank = a.get_shape().ndims
+ perm = (rank - 1) - np.arange(rank)
ret = transpose_fn(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function.
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index 195ad11c71..c9aa4d4889 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -282,9 +282,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
as is.
2. Tensors in the forward pass graph. These tensors may not be "live"
when the gradient is being computed. We replace such references by their
- corresponding tensor in the least common ancestor graph of `grad_graph` and
- `cond_graph`. Since we export intermediate tensors for all branch
- functions, this is always possible.
+ corresponding tensor in `cond_graph.outer_graph`. In the case of nested
+ control flow or functions, the gradient logic handling
+ `grad_graph.outer_graph` will make sure the tensor from
+ `cond_graph.outer_graph` is also correctly captured.
Args:
cond_graph: function.FuncGraph. The forward-pass function.
@@ -296,24 +297,23 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
new_inputs = []
for t in grad_graph.external_captures:
+ # `t` must either be in `grad_graph.outer_graph` or in the forward
+ # `cond_graph`.
if t.graph != grad_graph.outer_graph:
- # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
- # tensor to the least common ancestor of the `cond_graph` and
- # `grad_graph` so that it is "in-scope" for `grad_graph`.
- # TODO(srbs): `_is_ancestor` calls may be expensive. Compute the least
- # common ancestor once and re-use.
- assert _is_ancestor(cond_graph, t.graph)
- while not _is_ancestor(grad_graph, t.graph):
- assert isinstance(t.graph, _function.FuncGraph)
- if t in t.graph.internal_captures:
- # TODO(srbs): Consider building a map of internal_captures ->
- # external_captures instead of searching for `t` twice.
- t = t.graph.external_captures[t.graph.internal_captures.index(t)]
- else:
- # Note: All intermediate tensors are output by the If op.
- # TODO(srbs): .index() calls may be expensive. Optimize.
- t = t.graph._if.outputs[t.graph.outputs.index(t)]
- assert _is_ancestor(grad_graph, t.graph)
+ assert t.graph == cond_graph
+ # `internal_captures` are not treated as intermediates and hence not added
+ # to If op outputs. So we get the outer tensor corresponding to those
+ # from the list of `external_captures`.
+ try:
+ t = t.graph._if.outputs[t.graph.outputs.index(t)]
+ except ValueError:
+ index = t.graph.internal_captures.index(t)
+ t = t.graph.external_captures[index]
+
+ # Note: We rely on the capturing logic of the gradient If op graph to
+ # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
+ # and while_v2 handle this while building their gradient functions.
+ assert t.graph == cond_graph.outer_graph
new_inputs.append(t)
return new_inputs
@@ -492,11 +492,3 @@ def _get_output_shapes(true_graph_outputs, false_graph_outputs):
for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
]
return output_shapes
-
-
-def _is_ancestor(graph, maybe_ancestor):
- if maybe_ancestor == graph:
- return True
- if isinstance(graph, _function.FuncGraph):
- return _is_ancestor(graph.outer_graph, maybe_ancestor)
- return False
diff --git a/tensorflow/python/ops/control_flow_ops_benchmark.py b/tensorflow/python/ops/control_flow_ops_benchmark.py
new file mode 100644
index 0000000000..9ba5ff2c0f
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops_benchmark.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for control flow ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class CondWithManyIntermediatesBenchmark(test.Benchmark):
+ """Checks the runtime performance of outputting all intermediates."""
+
+ NUM_INTERMEDIATES = 1000
+ NUM_ITERS = 500
+ NUM_WARM_UP_ITERS = 50
+
+ def _create_cond(self, x):
+
+ def branch_fn():
+ # Use a random value so the adds can't be constant folded.
+ return x + sum(random_ops.random_normal([])
+ for _ in range(self.NUM_INTERMEDIATES))
+
+ # Use a dynamic predicate to make sure the cond isn't constant folded.
+ return control_flow_ops.cond(math_ops.not_equal(x, -1),
+ branch_fn, lambda: 0.0)
+
+ def _benchmark_defun(self):
+ """Benchmarks cond in a defun."""
+
+ @function.defun
+ def cond_fn(x):
+ return self._create_cond(x)
+
+ # Warm up
+ for _ in range(self.NUM_WARM_UP_ITERS):
+ cond_fn(0.0)
+
+ start_time = time.time()
+
+ for _ in range(self.NUM_ITERS):
+ cond_fn(0.0)
+
+ self.report_benchmark(
+ wall_time=time.time() - start_time,
+ iters=self.NUM_ITERS)
+
+ def _benchmark_graph(self):
+ """Benchmarks cond in legacy graph mode."""
+ with context.graph_mode():
+ with ops.Graph().as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ cond_val = self._create_cond(x)
+
+ with session.Session() as sess:
+ cond_fn = sess.make_callable(cond_val, [x])
+
+ # Warm up
+ for _ in range(self.NUM_WARM_UP_ITERS):
+ cond_fn(0.0)
+
+ start_time = time.time()
+
+ for _ in range(self.NUM_ITERS):
+ cond_fn(0.0)
+
+ self.report_benchmark(
+ wall_time=time.time() - start_time,
+ iters=self.NUM_ITERS)
+
+ def benchmark_cond_v1_defun(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = False
+ self._benchmark_defun()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v2_defun(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ self._benchmark_defun()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v1_graph(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = False
+ self._benchmark_graph()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v2_graph(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ self._benchmark_graph()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+if __name__ == "__main__":
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d7834ba350..bfe23834b7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+def copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes.resource or
+ target_t.dtype == dtypes.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
+ if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
+ # pylint: enable=protected-access
+ # Ensure that shapes and dtypes are propagated.
+ shapes, types = zip(*[(pair.shape, pair.dtype)
+ for pair in handle_data.shape_and_type])
+ ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+ shapes = [[d.size for d in s.dim]
+ if not s.unknown_rank else None for s in shapes]
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
+
+
@tf_export("custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
input_grads = nest.flatten(input_grads)
return ([None] * len(flat_result)) + input_grads + variable_grads
+ original_tensors = all_tensors
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
all_tensors = array_ops.identity_n(all_tensors)
+ for ot, t in zip(original_tensors, all_tensors):
+ copy_handle_data(ot, t)
return nest.pack_sequence_as(
structure=result, flat_sequence=all_tensors[:len(flat_result)])
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index aac95037dc..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -800,23 +800,21 @@ def _GradientsHelper(ys,
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
- if is_func_call:
- if is_partitioned_call:
- func_call = src_graph._get_function( # pylint: disable=protected-access
- compat.as_bytes(op.get_attr("f").name))
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ if is_func_call:
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
+ grad_fn = func_call.python_grad_func
else:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
- # Note that __defun is not set if the graph is
- # imported. If it's set, we prefer to access the original
- # defun.
- func_call = getattr(op, "__defun", func_call)
- grad_fn = func_call.python_grad_func
- else:
- # A grad_fn must be defined, either as a function or as None
- # for ops that do not have gradients.
- try:
- grad_fn = ops.get_gradient_function(op)
- except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 35fdee4fad..ff86df6346 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -602,20 +602,19 @@ class AdjustHueBenchmark(test.Benchmark):
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
- with session.Session("", graph=ops.Graph(), config=config) as sess:
- with ops.device(device):
- inputs = variables.Variable(
- random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
- trainable=False,
- dtype=dtypes.float32)
- delta = constant_op.constant(0.1, dtype=dtypes.float32)
- outputs = image_ops.adjust_hue(inputs, delta)
- run_op = control_flow_ops.group(outputs)
- sess.run(variables.global_variables_initializer())
- for i in xrange(warmup_rounds + benchmark_rounds):
- if i == warmup_rounds:
- start = time.time()
- sess.run(run_op)
+ with self.benchmark_session(config=config, device=device) as sess:
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = image_ops.adjust_hue(inputs, delta)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for i in xrange(warmup_rounds + benchmark_rounds):
+ if i == warmup_rounds:
+ start = time.time()
+ sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
@@ -646,21 +645,20 @@ class AdjustSaturationBenchmark(test.Benchmark):
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
- with session.Session("", graph=ops.Graph(), config=config) as sess:
- with ops.device(device):
- inputs = variables.Variable(
- random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
- trainable=False,
- dtype=dtypes.float32)
- delta = constant_op.constant(0.1, dtype=dtypes.float32)
- outputs = image_ops.adjust_saturation(inputs, delta)
- run_op = control_flow_ops.group(outputs)
- sess.run(variables.global_variables_initializer())
- for _ in xrange(warmup_rounds):
- sess.run(run_op)
- start = time.time()
- for _ in xrange(benchmark_rounds):
- sess.run(run_op)
+ with self.benchmark_session(config=config, device=device) as sess:
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = image_ops.adjust_saturation(inputs, delta)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for _ in xrange(warmup_rounds):
+ sess.run(run_op)
+ start = time.time()
+ for _ in xrange(benchmark_rounds):
+ sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
@@ -699,7 +697,7 @@ class ResizeBilinearBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
@@ -747,7 +745,7 @@ class ResizeBicubicBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
@@ -804,7 +802,7 @@ class ResizeAreaBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index e1a01ab4c3..902653befc 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -389,6 +389,21 @@ def _Relu6GradGrad(op, grad):
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+@ops.RegisterGradient("LeakyRelu")
+def _LeakyReluGrad(op, grad):
+ x = op.inputs[0]
+ alpha = op.get_attr("alpha")
+ return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)
+
+
+@ops.RegisterGradient("LeakyReluGrad")
+def _LeakyReluGradGrad(op, grad):
+ x = op.inputs[1]
+ alpha = op.get_attr("alpha")
+ return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+
+
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
return gen_nn_ops.elu_grad(grad, op.outputs[0])
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 1fbe31a098..04962da7f7 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,6 +22,7 @@ import numbers
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1602,6 +1603,8 @@ def leaky_relu(features, alpha=0.2, name=None):
features = ops.convert_to_tensor(features, name="features")
if features.dtype.is_integer:
features = math_ops.to_float(features)
+ if compat.forward_compatible(2018, 11, 1):
+ return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
return math_ops.maximum(alpha * features, features, name=name)
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index ff50fe0d09..a2da6412ed 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -217,21 +217,21 @@ def _features_to_raw_params(features, types):
feature = features[key]
if isinstance(feature, VarLenFeature):
if VarLenFeature not in types:
- raise ValueError("Unsupported VarLenFeature %s." % feature)
+ raise ValueError("Unsupported VarLenFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
sparse_keys.append(key)
sparse_types.append(feature.dtype)
elif isinstance(feature, SparseFeature):
if SparseFeature not in types:
- raise ValueError("Unsupported SparseFeature %s." % feature)
+ raise ValueError("Unsupported SparseFeature %s." % (feature,))
if not feature.index_key:
raise ValueError(
- "Missing index_key for SparseFeature %s." % feature)
+ "Missing index_key for SparseFeature %s." % (feature,))
if not feature.value_key:
raise ValueError(
- "Missing value_key for SparseFeature %s." % feature)
+ "Missing value_key for SparseFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
index_keys = feature.index_key
@@ -260,7 +260,7 @@ def _features_to_raw_params(features, types):
sparse_types.append(feature.dtype)
elif isinstance(feature, FixedLenFeature):
if FixedLenFeature not in types:
- raise ValueError("Unsupported FixedLenFeature %s." % feature)
+ raise ValueError("Unsupported FixedLenFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
@@ -281,7 +281,8 @@ def _features_to_raw_params(features, types):
dense_defaults[key] = feature.default_value
elif isinstance(feature, FixedLenSequenceFeature):
if FixedLenSequenceFeature not in types:
- raise ValueError("Unsupported FixedLenSequenceFeature %s." % feature)
+ raise ValueError("Unsupported FixedLenSequenceFeature %s." % (
+ feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
- function._copy_handle_data(src_t, tgt_t)
+ custom_gradient.copy_handle_data(src_t, tgt_t)
# TODO(srbs): Move to common utils for cond_v2 and while_v2.
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index fa17b17d10..4f7abb311a 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -27,6 +27,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
from tensorflow.python.platform import app
@@ -182,6 +183,19 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
throughput=throughput, extras=extras)
+@tf_export("test.benchmark_config")
+def benchmark_config():
+ """Returns a tf.ConfigProto for disabling the dependency optimizer.
+
+ Returns:
+ A TensorFlow ConfigProto object.
+ """
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.dependency_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+
@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
"""Abstract class that provides helpers for TensorFlow benchmarks."""
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 61e0abbfcb..adbce95c6f 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -209,6 +209,7 @@ limitations under the License.
SWIG_fail;
} else {
int num_outputs = $1->size();
+ Py_CLEAR($result);
$result = PyList_New(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
PyObject *output;
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 82f0e3be52..a479f38165 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -195,8 +195,12 @@ class Scaffold(object):
default_ready_op)
if self._ready_for_local_init_op is None:
def default_ready_for_local_init_op():
- return variables.report_uninitialized_variables(
- variables.global_variables())
+ return array_ops.concat([
+ variables.report_uninitialized_variables(
+ variables.global_variables()),
+ resources.report_uninitialized_resources(
+ resources.shared_resources())
+ ], 0)
self._ready_for_local_init_op = Scaffold.get_or_default(
'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
default_ready_for_local_init_op)
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 041266da3e..89bfcaf4ad 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export
@@ -36,9 +37,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The moving average of 'variable' updated with 'value' is:
variable * decay + value * (1 - decay)
- The returned Operation sets 'variable' to the newly computed moving average.
-
- The new value of 'variable' can be set with the 'AssignSub' op as:
+ The returned Operation sets 'variable' to the newly computed moving average,
+ by performing this subtraction:
variable -= (1 - decay) * (variable - value)
Since variables that are initialized to a `0` value will be `0` biased,
@@ -50,7 +50,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The names of the debias shadow variables, by default, include both the scope
they were created in and the scope of the variables they debias. They are also
- given a uniqifying-suffix.
+ given a uniquifying-suffix.
E.g.:
@@ -58,8 +58,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
with tf.variable_scope('scope1'):
with tf.variable_scope('scope2'):
var = tf.get_variable('foo')
- tf.assign_moving_average(var, 0.0, 1.0)
- tf.assign_moving_average(var, 0.0, 0.9)
+ update_1 = tf.assign_moving_average(var, 0.0, 1.0)
+ update_2 = tf.assign_moving_average(var, 0.0, 0.9)
# var.name: 'scope1/scope2/foo'
# shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
@@ -76,20 +76,33 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
name: Optional name of the returned operation.
Returns:
- A reference to the input 'variable' tensor with the newly computed
- moving average.
+ A tensor which if evaluated will compute and return the new moving average.
"""
+ def update_fn(v, value, decay=decay):
+ decay = ops.convert_to_tensor(1.0 - decay, name="decay")
+ if decay.dtype != v.dtype.base_dtype:
+ decay = math_ops.cast(decay, v.dtype.base_dtype)
+ if zero_debias:
+ update_delta = _zero_debias(v, value, decay)
+ else:
+ update_delta = (v - value) * decay
+ return state_ops.assign_sub(v, update_delta, name=scope)
+
with ops.name_scope(name, "AssignMovingAvg",
[variable, value, decay]) as scope:
- with ops.colocate_with(variable):
- decay = ops.convert_to_tensor(1.0 - decay, name="decay")
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- if zero_debias:
- update_delta = _zero_debias(variable, value, decay)
- else:
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ # In a tower context, we update variable using the mean of value across
+ # towers.
+ def merge_fn(strategy, v, value):
+ value = strategy.reduce(
+ variable_scope.VariableAggregation.MEAN, value, v)
+ return strategy.update(v, update_fn, value)
+
+ return tower_context.merge_call(merge_fn, variable, value)
+ else:
+ strategy = distribution_strategy_context.get_cross_tower_context()
+ return strategy.update(variable, update_fn, value)
def weighted_moving_average(value,
@@ -379,8 +392,6 @@ class ExponentialMovingAverage(object):
Raises:
TypeError: If the arguments are not an allowed type.
- ValueError: If the moving average of one of the variables is already
- being computed.
"""
# TODO(touts): op_scope
if var_list is None:
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index a0e6bf65cf..3a3af4bffa 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -63,6 +63,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import difflib
import six
@@ -101,10 +102,19 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva
if normalize_numbers:
NormalizeNumberFields(pb)
- self.assertMultiLineEqual(
- text_format.MessageToString(a, descriptor_pool=pool),
- text_format.MessageToString(b, descriptor_pool=pool),
- msg=msg)
+ a_str = text_format.MessageToString(a, descriptor_pool=pool)
+ b_str = text_format.MessageToString(b, descriptor_pool=pool)
+
+ # Some Python versions would perform regular diff instead of multi-line
+ # diff if string is longer than 2**16. We substitute this behavior
+ # with a call to unified_diff instead to have easier-to-read diffs.
+ # For context, see: https://bugs.python.org/issue11763.
+ if len(a_str) < 2**16 and len(b_str) < 2**16:
+ self.assertMultiLineEqual(a_str, b_str, msg=msg)
+ else:
+ diff = '\n' + ''.join(difflib.unified_diff(a_str.splitlines(True),
+ b_str.splitlines(True)))
+ self.fail('%s : %s' % (msg, diff))
def NormalizeNumberFields(pb):
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 7b3e618e84..11eb9ce947 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -825,18 +825,16 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
- PyObject* f1 = PyObject_GetAttrString(o1, "_fields");
- PyObject* f2 = PyObject_GetAttrString(o2, "_fields");
+ Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
+ Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
if (f1 == nullptr || f2 == nullptr) {
- Py_XDECREF(f1);
- Py_XDECREF(f2);
PyErr_SetString(
PyExc_RuntimeError,
"Expected namedtuple-like objects (that have _fields attr)");
return nullptr;
}
- if (PyObject_RichCompareBool(f1, f2, Py_NE)) {
+ if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
Py_RETURN_FALSE;
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
index 2e9de9ebb2..eb315e356d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
}
member_method {
+ name: "exponential"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "get"
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
index a71a59e269..9feb7c09b8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
@@ -46,7 +46,7 @@ tf_module {
}
member_method {
name: "batch_normalization"
- argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
+ argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'axis\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.001\'], "
}
member_method {
name: "batch_set_value"
@@ -98,7 +98,7 @@ tf_module {
}
member_method {
name: "conv2d_transpose"
- argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\'], "
+ argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], "
}
member_method {
name: "conv3d"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index c3dd2ad046..014f5828fa 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index c440604aae..a6e4856de9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index 065bb4d35b..381839d6de 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index c7ba6056f9..2933f9f4b3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 8f4f7918ab..9c9c7461c8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 93c442bd55..44ca598724 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 5ea61d118d..a8094c0bde 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 11dca17c6d..3ebe162f57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 278429af6f..c0a53b847b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 935a69ab2f..ff6c6f3ec4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 238d96cca6..d26da270e7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 4a45bf7997..524c5fd69e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
index 81b91d2780..138d97b11f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
@@ -70,6 +70,6 @@ tf_module {
}
member_method {
name: "to_categorical"
- argspec: "args=[\'y\', \'num_classes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt
index abe9b068ae..984c584c9e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt
@@ -21,6 +21,10 @@ tf_module {
argspec: "args=[\'actual\', \'expected\', \'checkpoint_v2\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
+ name: "benchmark_config"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "compute_gradient"
argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
index 2e9de9ebb2..eb315e356d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
}
member_method {
+ name: "exponential"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "get"
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
index a71a59e269..9feb7c09b8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
@@ -46,7 +46,7 @@ tf_module {
}
member_method {
name: "batch_normalization"
- argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
+ argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'axis\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.001\'], "
}
member_method {
name: "batch_set_value"
@@ -98,7 +98,7 @@ tf_module {
}
member_method {
name: "conv2d_transpose"
- argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\'], "
+ argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], "
}
member_method {
name: "conv3d"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index c3dd2ad046..014f5828fa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index c440604aae..a6e4856de9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index 065bb4d35b..381839d6de 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index c7ba6056f9..2933f9f4b3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 8f4f7918ab..9c9c7461c8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 93c442bd55..44ca598724 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 5ea61d118d..a8094c0bde 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 11dca17c6d..3ebe162f57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 278429af6f..c0a53b847b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 935a69ab2f..ff6c6f3ec4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 238d96cca6..d26da270e7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 4a45bf7997..524c5fd69e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
index 81b91d2780..138d97b11f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
@@ -70,6 +70,6 @@ tf_module {
}
member_method {
name: "to_categorical"
- argspec: "args=[\'y\', \'num_classes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
index abe9b068ae..984c584c9e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
@@ -21,6 +21,10 @@ tf_module {
argspec: "args=[\'actual\', \'expected\', \'checkpoint_v2\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
+ name: "benchmark_config"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "compute_gradient"
argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/ci_build/Dockerfile.android b/tensorflow/tools/ci_build/Dockerfile.android
index dcf077791a..7e72eb0cbf 100644
--- a/tensorflow/tools/ci_build/Dockerfile.android
+++ b/tensorflow/tools/ci_build/Dockerfile.android
@@ -45,9 +45,14 @@ ENV ANDROID_NDK_FILENAME android-ndk-r14b-linux-x86_64.zip
ENV ANDROID_NDK_URL https://dl.google.com/android/repository/${ANDROID_NDK_FILENAME}
ENV ANDROID_NDK_HOME ${ANDROID_DEV_HOME}/ndk
ENV PATH ${PATH}:${ANDROID_NDK_HOME}
+# Workaround for b/117156972: inject missing #include into NDK versions of
+# futex.h.
RUN cd ${ANDROID_DEV_HOME} && \
wget -q ${ANDROID_NDK_URL} && \
unzip ${ANDROID_NDK_FILENAME} -d ${ANDROID_DEV_HOME} && \
+ sed -i 15i"#include <linux/compiler.h>" ${ANDROID_DEV_HOME}/android-ndk-r14b/platforms/android-14/arch-arm/usr/include/linux/futex.h && \
+ sed -i 15i"#include <linux/compiler.h>" ${ANDROID_DEV_HOME}/android-ndk-r14b/platforms/android-14/arch-mips/usr/include/linux/futex.h && \
+ sed -i 15i"#include <linux/compiler.h>" ${ANDROID_DEV_HOME}/android-ndk-r14b/platforms/android-14/arch-x86/usr/include/linux/futex.h && \
rm ${ANDROID_NDK_FILENAME} && \
bash -c "ln -s ${ANDROID_DEV_HOME}/android-ndk-* ${ANDROID_NDK_HOME}"
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index fdff867ff0..489722c0e9 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -423,7 +423,7 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] ||
[[ ${CTYPE} == "debian.jessie.cpu" ]]; then
# CPU only command, fully parallel.
NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} "\
- "${EXTRA_ARGS} -- ${BAZEL_TARGET}"
+"${EXTRA_ARGS} -- ${BAZEL_TARGET}"
elif [[ ${CTYPE} == gpu* ]]; then
# GPU only command, run as many jobs as the GPU count only.
NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index d864a7a039..dd1dca9ee8 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -226,13 +226,14 @@ if os.name == 'nt':
else:
EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.so'
-headers = (list(find_files('*.h', 'tensorflow/core')) +
- list(find_files('*.h', 'tensorflow/stream_executor')) +
- list(find_files('*.h', 'google/protobuf_archive/src')) +
- list(find_files('*', 'third_party/eigen3')) +
- list(find_files('*.h',
- 'tensorflow/include/external/com_google_absl')) +
- list(find_files('*', 'tensorflow/include/external/eigen_archive')))
+headers = (
+ list(find_files('*.h', 'tensorflow/core')) + list(
+ find_files('*.h', 'tensorflow/stream_executor')) +
+ list(find_files('*.h', 'google/protobuf_archive/src')) + list(
+ find_files('*', 'third_party/eigen3')) + list(
+ find_files('*.h', 'tensorflow/include/external/com_google_absl')) +
+ list(find_files('*.inc', 'tensorflow/include/external/com_google_absl')) +
+ list(find_files('*', 'tensorflow/include/external/eigen_archive')))
setup(
name=project_name,
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8df41f96b8..adeac62e43 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -22,10 +22,14 @@ load(
)
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
load("//third_party/icu:workspace.bzl", icu = "repo")
+load("//third_party/jpeg:workspace.bzl", jpeg = "repo")
+load("//third_party/nasm:workspace.bzl", nasm = "repo")
def initialize_third_party():
flatbuffers()
icu()
+ jpeg()
+ nasm()
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
@@ -110,11 +114,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_absl",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
- sha256 = "507903ef9353cb25cccd0a6840048fdd348fd20e98314d694f04a990c0f277e3",
- strip_prefix = "abseil-cpp-f21d187b80e3b7f08fb279775ea9c8b48c636030",
+ sha256 = "f186bf5d9fce3037c602a21f86facbdd317adecef36e1726ec7bc7b496943a82",
+ strip_prefix = "abseil-cpp-e821380d69a549dc64900693942789d21aa4df5e",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/f21d187b80e3b7f08fb279775ea9c8b48c636030.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/f21d187b80e3b7f08fb279775ea9c8b48c636030.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/e821380d69a549dc64900693942789d21aa4df5e.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/e821380d69a549dc64900693942789d21aa4df5e.tar.gz",
],
)
@@ -234,31 +238,6 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
)
tf_http_archive(
- name = "nasm",
- build_file = clean_dep("//third_party:nasm.BUILD"),
- sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
- strip_prefix = "nasm-2.13.03",
- system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
- urls = [
- "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
- "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
- "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
- ],
- )
-
- tf_http_archive(
- name = "jpeg",
- build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
- sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
- strip_prefix = "libjpeg-turbo-2.0.0",
- system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
- urls = [
- "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
- "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
- ],
- )
-
- tf_http_archive(
name = "png_archive",
build_file = clean_dep("//third_party:png.BUILD"),
patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
diff --git a/third_party/jpeg/BUILD b/third_party/jpeg/BUILD
index 5b01f6e3e4..e3aec1fce9 100644
--- a/third_party/jpeg/BUILD
+++ b/third_party/jpeg/BUILD
@@ -1 +1 @@
-licenses(["notice"])
+# Needed to make this a package.
diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/BUILD.bazel
index 1b9b9bf2f5..5243e995a3 100644
--- a/third_party/jpeg/jpeg.BUILD
+++ b/third_party/jpeg/BUILD.bazel
@@ -162,9 +162,9 @@ cc_library(
hdrs = [
"simd/powerpc/jccolext-altivec.c",
"simd/powerpc/jcgryext-altivec.c",
+ "simd/powerpc/jcsample.h",
"simd/powerpc/jdcolext-altivec.c",
"simd/powerpc/jdmrgext-altivec.c",
- "simd/powerpc/jcsample.h",
"simd/powerpc/jsimd_altivec.h",
],
copts = libjpegturbo_copts,
@@ -186,7 +186,6 @@ cc_library(
"jsimd.h",
"jsimddct.h",
"simd/jsimd.h",
- "simd/x86_64/jsimd.c",
"simd/x86_64/jccolor-avx2.o",
"simd/x86_64/jccolor-sse2.o",
"simd/x86_64/jcgray-avx2.o",
@@ -213,6 +212,7 @@ cc_library(
"simd/x86_64/jquantf-sse2.o",
"simd/x86_64/jquanti-avx2.o",
"simd/x86_64/jquanti-sse2.o",
+ "simd/x86_64/jsimd.c",
"simd/x86_64/jsimdcpu.o",
],
copts = libjpegturbo_copts,
@@ -322,9 +322,9 @@ cc_library(
"jpeglib.h",
"jsimd.h",
"jsimddct.h",
- "simd/jsimd.h",
"simd/arm/jsimd.c",
"simd/arm/jsimd_neon.S",
+ "simd/jsimd.h",
],
copts = libjpegturbo_copts,
nocopts = libjpegturbo_nocopts,
@@ -343,9 +343,9 @@ cc_library(
"jpeglib.h",
"jsimd.h",
"jsimddct.h",
- "simd/jsimd.h",
"simd/arm64/jsimd.c",
"simd/arm64/jsimd_neon.S",
+ "simd/jsimd.h",
],
copts = libjpegturbo_copts,
nocopts = libjpegturbo_nocopts,
@@ -366,7 +366,6 @@ cc_library(
"jsimd.h",
"jsimddct.h",
"simd/jsimd.h",
- "simd/x86_64/jsimd.c",
"simd/x86_64/jccolor-avx2.obj",
"simd/x86_64/jccolor-sse2.obj",
"simd/x86_64/jcgray-avx2.obj",
@@ -393,6 +392,7 @@ cc_library(
"simd/x86_64/jquantf-sse2.obj",
"simd/x86_64/jquanti-avx2.obj",
"simd/x86_64/jquanti-sse2.obj",
+ "simd/x86_64/jsimd.c",
"simd/x86_64/jsimdcpu.obj",
],
copts = libjpegturbo_copts,
@@ -603,6 +603,7 @@ JCONFIGINT_WIN_SUBSTITUTIONS = {
}
JCONFIGINT_NOWIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS)
+
JCONFIGINT_WIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS)
template_rule(
diff --git a/third_party/systemlibs/jpeg.BUILD b/third_party/jpeg/BUILD.system
index f4f52da9bd..f4f52da9bd 100644
--- a/third_party/systemlibs/jpeg.BUILD
+++ b/third_party/jpeg/BUILD.system
diff --git a/third_party/jpeg/jpeg_helpers.BUILD.bazel b/third_party/jpeg/jpeg_helpers.BUILD.bazel
new file mode 100644
index 0000000000..5b01f6e3e4
--- /dev/null
+++ b/third_party/jpeg/jpeg_helpers.BUILD.bazel
@@ -0,0 +1 @@
+licenses(["notice"])
diff --git a/third_party/jpeg/workspace.bzl b/third_party/jpeg/workspace.bzl
new file mode 100644
index 0000000000..2bb7dacd32
--- /dev/null
+++ b/third_party/jpeg/workspace.bzl
@@ -0,0 +1,16 @@
+"""loads the jpeg library, used by TF."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+ third_party_http_archive(
+ name = "jpeg",
+ urls = [
+ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
+ "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
+ ],
+ sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
+ strip_prefix = "libjpeg-turbo-2.0.0",
+ build_file = "//third_party/jpeg:BUILD.bazel",
+ system_build_file = "//third_party/jpeg:BUILD.system",
+ )
diff --git a/third_party/nasm/BUILD b/third_party/nasm/BUILD
new file mode 100644
index 0000000000..e3aec1fce9
--- /dev/null
+++ b/third_party/nasm/BUILD
@@ -0,0 +1 @@
+# Needed to make this a package.
diff --git a/third_party/nasm.BUILD b/third_party/nasm/BUILD.bazel
index d746a65e7e..c68d713946 100644
--- a/third_party/nasm.BUILD
+++ b/third_party/nasm/BUILD.bazel
@@ -137,12 +137,6 @@ cc_binary(
":windows": ["config/msvc.h"],
"//conditions:default": [],
}),
- includes = [
- "asm",
- "include",
- "output",
- "x86",
- ],
copts = select({
":windows": [],
"//conditions:default": [
@@ -157,6 +151,12 @@ cc_binary(
"HAVE_SYS_TYPES_H",
],
}),
+ includes = [
+ "asm",
+ "include",
+ "output",
+ "x86",
+ ],
visibility = ["@jpeg//:__pkg__"],
)
diff --git a/third_party/systemlibs/nasm.BUILD b/third_party/nasm/BUILD.system
index 10ef8d8832..10ef8d8832 100644
--- a/third_party/systemlibs/nasm.BUILD
+++ b/third_party/nasm/BUILD.system
diff --git a/third_party/nasm/workspace.bzl b/third_party/nasm/workspace.bzl
new file mode 100644
index 0000000000..6d50f6fcad
--- /dev/null
+++ b/third_party/nasm/workspace.bzl
@@ -0,0 +1,17 @@
+"""loads the nasm library, used by TF."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+ third_party_http_archive(
+ name = "nasm",
+ urls = [
+ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
+ "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ ],
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
+ build_file = "//third_party/nasm:BUILD.bazel",
+ system_build_file = "//third_party/nasm:BUILD.system",
+ )