author     avijit-nervana <avijit.chakraborty@intel.com>  2018-08-06 13:34:24 -0700
committer  avijit-nervana <avijit.chakraborty@intel.com>  2018-08-06 13:34:24 -0700
commit     1149ad359f4e51a4e2c37a1dab8112056a38ef9b (patch)
tree       5b5e99bc5e3c6d6859620d45a595bbaaec9d0f9d /tensorflow
parent     46c2eafd65fd55d0837a9a86e8843f7f6d615990 (diff)
parent     9e35139d16bf259794a23e60c9f2b3f4e38c3b48 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD1
-rw-r--r--tensorflow/c/eager/c_api.cc75
-rw-r--r--tensorflow/c/eager/c_api.h21
-rw-r--r--tensorflow/c/eager/c_api_internal.h18
-rw-r--r--tensorflow/c/eager/c_api_test.cc201
-rw-r--r--tensorflow/cc/BUILD31
-rw-r--r--tensorflow/cc/client/client_session.cc18
-rw-r--r--tensorflow/cc/client/client_session.h28
-rw-r--r--tensorflow/cc/client/client_session_test.cc21
-rw-r--r--tensorflow/cc/framework/gradient_checker.cc12
-rw-r--r--tensorflow/cc/framework/gradient_checker_test.cc16
-rw-r--r--tensorflow/cc/gradients/image_grad.cc74
-rw-r--r--tensorflow/cc/gradients/image_grad_test.cc157
-rw-r--r--tensorflow/compiler/aot/BUILD25
-rw-r--r--tensorflow/compiler/aot/codegen.cc6
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl652
-rw-r--r--tensorflow/compiler/jit/BUILD31
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc28
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_internal.h32
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_test.cc24
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc1
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.cc18
-rw-r--r--tensorflow/compiler/tests/eager_test.py15
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc96
-rw-r--r--tensorflow/compiler/tf2xla/BUILD25
-rw-r--r--tensorflow/compiler/tf2xla/cpu_function_runtime.cc (renamed from tensorflow/compiler/aot/runtime.cc)30
-rw-r--r--tensorflow/compiler/tf2xla/cpu_function_runtime.h (renamed from tensorflow/compiler/aot/runtime.h)32
-rw-r--r--tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc (renamed from tensorflow/compiler/aot/runtime_test.cc)45
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc32
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h29
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc12
-rw-r--r--tensorflow/compiler/xla/client/lib/prng.cc2
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc15
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc13
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc35
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h14
-rw-r--r--tensorflow/compiler/xla/layout_util.cc2
-rw-r--r--tensorflow/compiler/xla/literal_util.cc1
-rw-r--r--tensorflow/compiler/xla/metric_table_report.cc7
-rw-r--r--tensorflow/compiler/xla/python_api/BUILD2
-rw-r--r--tensorflow/compiler/xla/python_api/types.py7
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc10
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/compiler_functor.cc21
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc23
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc104
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h27
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc29
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc414
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h98
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.cc44
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matmul.cc24
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc40
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc22
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h1
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h3
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc130
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc32
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc28
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h1
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc52
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc14
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc55
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc61
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h42
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc49
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc98
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h41
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc70
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc100
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_fix.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc49
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc18
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h1
-rw-r--r--tensorflow/compiler/xla/service/human_readable_profile_builder.cc53
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc6
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/service.cc1
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc389
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h10
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc731
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.cc15
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc9
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc4
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc2
-rw-r--r--tensorflow/compiler/xla/shape_util.cc5
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc64
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test.cc7
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc2
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc29
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc8
-rw-r--r--tensorflow/compiler/xla/tools/BUILD1
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc8
-rw-r--r--tensorflow/compiler/xla/xla_data.proto14
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/autograph/converters/call_trees.py2
-rw-r--r--tensorflow/contrib/autograph/converters/directives.py22
-rw-r--r--tensorflow/contrib/autograph/converters/directives_test.py19
-rw-r--r--tensorflow/contrib/autograph/converters/error_handlers_test.py6
-rw-r--r--tensorflow/contrib/autograph/core/converter.py2
-rw-r--r--tensorflow/contrib/autograph/core/errors_test.py3
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/keras_test.py41
-rw-r--r--tensorflow/contrib/autograph/impl/api.py9
-rw-r--r--tensorflow/contrib/autograph/impl/api_test.py21
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py13
-rw-r--r--tensorflow/contrib/autograph/impl/conversion_test.py2
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow.py4
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info.py19
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info_test.py3
-rw-r--r--tensorflow/contrib/bigtable/README.md7
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py40
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD2
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py48
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py134
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py216
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py16
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py17
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py3
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py3
-rw-r--r--tensorflow/contrib/cmake/external/eigen.cmake7
-rw-r--r--tensorflow/contrib/cmake/external/highwayhash.cmake36
-rw-r--r--tensorflow/contrib/cmake/external/nsync.cmake42
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt1
-rw-r--r--tensorflow/contrib/coder/BUILD44
-rw-r--r--tensorflow/contrib/coder/README.md73
-rw-r--r--tensorflow/contrib/coder/__init__.py3
-rw-r--r--tensorflow/contrib/coder/python/layers/entropybottleneck.py697
-rw-r--r--tensorflow/contrib/coder/python/layers/entropybottleneck_test.py315
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD16
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py31
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py36
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py29
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py189
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py27
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py44
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py43
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py3
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py30
-rw-r--r--tensorflow/contrib/distribute/BUILD3
-rw-r--r--tensorflow/contrib/distribute/__init__.py6
-rw-r--r--tensorflow/contrib/distribute/python/BUILD65
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py205
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py217
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py131
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py172
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py151
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py375
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py138
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_test_base.py73
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py115
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py55
-rw-r--r--tensorflow/contrib/eager/python/datasets.py32
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py14
-rw-r--r--tensorflow/contrib/eager/python/examples/BUILD2
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py3
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/BUILD59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/config.py72
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops.py71
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops_test.py59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan.py232
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan_test.py101
-rw-r--r--tensorflow/contrib/eager/python/examples/spinn/spinn_test.py4
-rw-r--r--tensorflow/contrib/framework/__init__.py1
-rw-r--r--tensorflow/contrib/framework/python/framework/checkpoint_utils.py4
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py14
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py5
-rw-r--r--tensorflow/contrib/lite/BUILD16
-rw-r--r--tensorflow/contrib/lite/Makefile13
-rw-r--r--tensorflow/contrib/lite/allocation.cc45
-rw-r--r--tensorflow/contrib/lite/allocation.h2
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD49
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.cc102
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.h57
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc150
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel_test.cc197
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.cc154
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.h97
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h8
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity237
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs27
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset3
-rw-r--r--tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md3
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/BUILD84
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h150
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h79
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h420
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc247
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc238
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h114
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h50
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/top_n.h341
-rw-r--r--tensorflow/contrib/lite/interpreter.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD30
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc55
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc44
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD3
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h34
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h42
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/register.h4
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_to_dense.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/tile.cc5
-rw-r--r--tensorflow/contrib/lite/mmap_allocation.cc61
-rw-r--r--tensorflow/contrib/lite/mmap_allocation_disabled.cc39
-rw-r--r--tensorflow/contrib/lite/model.cc11
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor.h4
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc4
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h8
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate_disabled.cc42
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs10
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h236
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py20
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc3
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc123
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc37
-rw-r--r--tensorflow/contrib/lite/toco/model.h26
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc29
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc14
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.cc25
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1
-rwxr-xr-xtensorflow/contrib/makefile/download_dependencies.sh6
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py13
-rw-r--r--tensorflow/contrib/predictor/contrib_estimator_predictor.py5
-rw-r--r--tensorflow/contrib/predictor/predictor_factories.py10
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py21
-rw-r--r--tensorflow/contrib/quantize/python/quantize_test.py55
-rw-r--r--tensorflow/contrib/tensorrt/BUILD21
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc577
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc52
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h40
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc8
-rw-r--r--tensorflow/contrib/tensorrt/convert/utils.cc34
-rw-r--r--tensorflow/contrib/tensorrt/convert/utils.h11
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc29
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.h2
-rw-r--r--tensorflow/contrib/tensorrt/python/__init__.py4
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert.py90
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc47
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment_test.cc4
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py252
-rw-r--r--tensorflow/contrib/tensorrt/test/batch_matmul_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py5
-rw-r--r--tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py19
-rw-r--r--tensorflow/contrib/tensorrt/test/concatenation_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/const_broadcast_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/memory_alignment_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/neighboring_engine_test.py13
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py348
-rw-r--r--tensorflow/contrib/tensorrt/test/unary_test.py5
-rw-r--r--tensorflow/contrib/tensorrt/test/utils.cc101
-rw-r--r--tensorflow/contrib/tensorrt/test/utils.h44
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/test/vgg_block_test.py2
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i114
-rw-r--r--tensorflow/contrib/timeseries/__init__.py3
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/__init__.py1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py168
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py81
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py60
-rw-r--r--tensorflow/contrib/tpu/BUILD19
-rw-r--r--tensorflow/contrib/tpu/__init__.py5
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc172
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py13
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py9
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py11
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py3
-rw-r--r--tensorflow/contrib/training/python/training/evaluation.py4
-rw-r--r--tensorflow/contrib/training/python/training/training_test.py3
-rw-r--r--tensorflow/core/BUILD3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IteratorGetNextAsOptional.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_OptionalFromValue.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_OptionalGetValue.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_OptionalHasValue.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_OptionalNone.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IteratorGetNextAsOptional.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_OptionalFromValue.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_OptionalGetValue.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_OptionalHasValue.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_OptionalNone.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc26
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc100
-rw-r--r--tensorflow/core/common_runtime/eager/context.h68
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc269
-rw-r--r--tensorflow/core/common_runtime/executor.cc46
-rw-r--r--tensorflow/core/common_runtime/placer.cc4
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc8
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc3
-rw-r--r--tensorflow/core/distributed_runtime/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc10
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h2
-rw-r--r--tensorflow/core/framework/op_kernel.cc13
-rw-r--r--tensorflow/core/framework/op_kernel.h6
-rw-r--r--tensorflow/core/framework/step_stats.proto5
-rw-r--r--tensorflow/core/framework/tensor.cc4
-rw-r--r--tensorflow/core/framework/tensor_testutil_test.cc7
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc178
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h2
-rw-r--r--tensorflow/core/grappler/graph_view.cc10
-rw-r--r--tensorflow/core/grappler/graph_view.h2
-rw-r--r--tensorflow/core/grappler/mutable_graph_view.cc20
-rw-r--r--tensorflow/core/grappler/mutable_graph_view.h7
-rw-r--r--tensorflow/core/grappler/mutable_graph_view_test.cc67
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD35
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc41
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h7
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc28
-rw-r--r--tensorflow/core/grappler/optimizers/data/latency_all_edges.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/latency_all_edges.h46
-rw-r--r--tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc92
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc51
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer_test.cc23
-rw-r--r--tensorflow/core/grappler/utils.h3
-rw-r--r--tensorflow/core/grappler/utils/functions.cc2
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.cc9
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.h3
-rw-r--r--tensorflow/core/kernels/BUILD2
-rw-r--r--tensorflow/core/kernels/data/BUILD16
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc392
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc85
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc270
-rw-r--r--tensorflow/core/kernels/data/optional_ops.h36
-rw-r--r--tensorflow/core/kernels/functional_ops.cc73
-rw-r--r--tensorflow/core/kernels/quantize_and_dequantize_op.h16
-rw-r--r--tensorflow/core/kernels/quantize_and_dequantize_op_test.cc16
-rw-r--r--tensorflow/core/kernels/spacetobatch_op.cc113
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.cc10
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt120
-rw-r--r--tensorflow/core/ops/dataset_ops.cc29
-rw-r--r--tensorflow/core/ops/functional_ops.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt83
-rw-r--r--tensorflow/core/platform/cloud/BUILD1
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc110
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h37
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc1413
-rw-r--r--tensorflow/core/platform/env.h5
-rw-r--r--tensorflow/core/platform/s3/s3_crypto.cc113
-rw-r--r--tensorflow/core/platform/s3/s3_crypto.h35
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc4
-rw-r--r--tensorflow/core/protobuf/worker.proto5
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_entry.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_scorer.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_search.h1
-rw-r--r--tensorflow/core/util/ctc/ctc_decoder.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_loss_util.h2
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.cc772
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.h11
-rw-r--r--tensorflow/core/util/mkl_util.h17
-rw-r--r--tensorflow/docs_src/guide/custom_estimators.md4
-rw-r--r--tensorflow/docs_src/install/install_c.md2
-rw-r--r--tensorflow/docs_src/install/install_go.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md22
-rw-r--r--tensorflow/docs_src/install/install_linux.md18
-rw-r--r--tensorflow/docs_src/install/install_mac.md10
-rw-r--r--tensorflow/docs_src/install/install_sources.md9
-rw-r--r--tensorflow/docs_src/performance/xla/jit.md12
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md196
-rw-r--r--tensorflow/docs_src/performance/xla/tfcompile.md5
-rw-r--r--tensorflow/go/op/wrappers.go1708
-rw-r--r--tensorflow/java/BUILD13
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/DataType.java32
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Session.java18
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java15
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java513
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java68
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java107
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java165
-rw-r--r--tensorflow/python/BUILD102
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD23
-rw-r--r--tensorflow/python/data/kernel_tests/cache_dataset_op_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/iterator_ops_test.py96
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py12
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py186
-rw-r--r--tensorflow/python/data/ops/BUILD21
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py41
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py67
-rw-r--r--tensorflow/python/data/ops/optional_ops.py209
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py48
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_test.py110
-rw-r--r--tensorflow/python/eager/backprop.py9
-rw-r--r--tensorflow/python/eager/context.py82
-rw-r--r--tensorflow/python/eager/function.py62
-rw-r--r--tensorflow/python/eager/function_test.py186
-rw-r--r--tensorflow/python/eager/graph_callable.py2
-rw-r--r--tensorflow/python/estimator/estimator.py33
-rw-r--r--tensorflow/python/estimator/estimator_test.py4
-rw-r--r--tensorflow/python/estimator/export/export.py75
-rw-r--r--tensorflow/python/estimator/export/export_test.py6
-rw-r--r--tensorflow/python/estimator/keras.py3
-rw-r--r--tensorflow/python/framework/error_interpolation.py96
-rw-r--r--tensorflow/python/framework/error_interpolation_test.py15
-rw-r--r--tensorflow/python/framework/function.py10
-rw-r--r--tensorflow/python/framework/ops.py50
-rw-r--r--tensorflow/python/framework/ops_test.py22
-rw-r--r--tensorflow/python/framework/test_util.py5
-rwxr-xr-xtensorflow/python/keras/BUILD4
-rw-r--r--tensorflow/python/keras/backend.py60
-rw-r--r--tensorflow/python/keras/engine/base_layer.py58
-rw-r--r--tensorflow/python/keras/engine/distributed_training_utils.py249
-rw-r--r--tensorflow/python/keras/engine/network.py230
-rw-r--r--tensorflow/python/keras/engine/saving_test.py17
-rw-r--r--tensorflow/python/keras/engine/topology_test.py96
-rw-r--r--tensorflow/python/keras/engine/training.py272
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py4
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py455
-rw-r--r--tensorflow/python/keras/engine/training_eager.py51
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py50
-rw-r--r--tensorflow/python/keras/engine/training_test.py121
-rw-r--r--tensorflow/python/keras/engine/training_utils.py27
-rw-r--r--tensorflow/python/keras/layers/gru_test.py4
-rw-r--r--tensorflow/python/keras/layers/lstm_test.py7
-rw-r--r--tensorflow/python/keras/layers/recurrent.py337
-rw-r--r--tensorflow/python/keras/layers/simplernn_test.py4
-rw-r--r--tensorflow/python/keras/metrics.py27
-rw-r--r--tensorflow/python/keras/metrics_test.py7
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py21
-rw-r--r--tensorflow/python/keras/models.py43
-rw-r--r--tensorflow/python/keras/models_test.py16
-rw-r--r--tensorflow/python/keras/utils/generic_utils.py6
-rw-r--r--tensorflow/python/kernel_tests/BUILD3
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_exponential_op_test.py114
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py2
-rw-r--r--tensorflow/python/lib/core/ndarray_tensor.cc69
-rw-r--r--tensorflow/python/lib/core/py_func.cc53
-rw-r--r--tensorflow/python/lib/io/py_record_writer.cc3
-rw-r--r--tensorflow/python/lib/io/tf_record_test.py35
-rw-r--r--tensorflow/python/ops/control_flow_ops.py127
-rw-r--r--tensorflow/python/ops/custom_gradient.py10
-rw-r--r--tensorflow/python/ops/linalg/BUILD1
-rw-r--r--tensorflow/python/ops/linalg/linalg_impl.py216
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py15
-rw-r--r--tensorflow/python/pywrap_tfe.i2
-rw-r--r--tensorflow/python/tools/freeze_graph.py3
-rw-r--r--tensorflow/python/training/checkpoint_management.py406
-rw-r--r--tensorflow/python/training/checkpoint_management_test.py316
-rw-r--r--tensorflow/python/training/checkpoint_utils.py3
-rw-r--r--tensorflow/python/training/checkpointable/BUILD4
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py16
-rw-r--r--tensorflow/python/training/monitored_session_test.py5
-rw-r--r--tensorflow/python/training/saver.py401
-rw-r--r--tensorflow/python/training/saver_test.py460
-rw-r--r--tensorflow/python/training/session_manager.py6
-rw-r--r--tensorflow/python/training/session_manager_test.py5
-rw-r--r--tensorflow/python/training/supervisor_test.py3
-rw-r--r--tensorflow/python/training/training.py12
-rw-r--r--tensorflow/python/training/training_util.py8
-rw-r--r--tensorflow/python/util/deprecation.py10
-rw-r--r--tensorflow/python/util/nest_test.py4
-rw-r--r--tensorflow/python/util/tf_inspect.py2
-rw-r--r--tensorflow/python/util/tf_inspect_test.py12
-rw-r--r--tensorflow/python/util/util.cc37
-rw-r--r--tensorflow/stream_executor/blas.h66
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc120
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc3
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.cc16
-rw-r--r--tensorflow/stream_executor/stream.cc289
-rw-r--r--tensorflow/stream_executor/stream.h39
-rw-r--r--tensorflow/stream_executor/stream_test.cc90
-rw-r--r--tensorflow/tensorflow.bzl4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt1
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt2
-rwxr-xr-xtensorflow/tools/ci_build/builds/android.sh8
-rwxr-xr-xtensorflow/tools/ci_build/builds/pip.sh5
-rwxr-xr-xtensorflow/tools/ci_build/builds/run_pip_tests.sh3
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh8
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh2
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh6
-rw-r--r--tensorflow/tools/common/public_api.py1
-rw-r--r--tensorflow/tools/docker/Dockerfile2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel4
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu4
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl4
-rw-r--r--tensorflow/tools/docker/Dockerfile.gpu2
-rw-r--r--tensorflow/tools/docker/README.md6
-rw-r--r--tensorflow/tools/pip_package/BUILD1
-rw-r--r--tensorflow/tools/pip_package/setup.py4
-rw-r--r--tensorflow/workspace.bzl1833
536 files changed, 23171 insertions, 9922 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 529707b6fc..161e5f80d4 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -387,6 +387,7 @@ config_setting(
define_values = {
"dynamic_loaded_kernels": "true",
},
+ visibility = ["//visibility:public"],
)
config_setting(
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 7321b4b791..a0a44440c8 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -150,8 +150,8 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::Status::OK();
}
-tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
- TFE_Context** ctx) {
+tensorflow::Status UpdateTFE_ContextWithServerDef(
+ const tensorflow::ServerDef& server_def, TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
@@ -165,12 +165,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
} \
} while (0);
- string worker_name = tensorflow::strings::StrCat(
- "/job:", opts->server_def.job_name(),
- "/replica:0/task:", opts->server_def.task_index());
+ string worker_name =
+ tensorflow::strings::StrCat("/job:", server_def.job_name(),
+ "/replica:0/task:", server_def.task_index());
std::unique_ptr<tensorflow::ServerInterface> server;
- LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server));
+ LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(server.get());
@@ -202,15 +202,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
- remote_workers, rendezvous_id, opts->server_def,
- remote_eager_workers.get(), opts->async, &remote_contexts));
+ remote_workers, rendezvous_id, server_def, remote_eager_workers.get(),
+ ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
- session_name, opts->server_def, true));
+ session_name, server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
@@ -221,10 +221,10 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr;
- *ctx = new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, device_mgr, r, std::move(server),
- std::move(remote_eager_workers),
- std::move(remote_device_mgr), remote_contexts);
+
+ ctx->context.InitializeRemote(
+ std::move(server), std::move(remote_eager_workers),
+ std::move(remote_device_mgr), remote_contexts, r, device_mgr);
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@@ -249,15 +249,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status) {
- if (!options->server_def.ParseFromArray(proto, proto_len)) {
- status->status = tensorflow::errors::InvalidArgument(
- "Invalid tensorflow.ServerDef protocol buffer");
- }
-}
-
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@@ -267,12 +258,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
- if (!opts->server_def.job_name().empty()) {
- TFE_Context* ctx = nullptr;
- status->status = NewRemoteAwareTFE_Context(opts, &ctx);
- return ctx;
- }
-
std::vector<tensorflow::Device*> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@@ -301,6 +286,20 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
+// Set server_def on the context, possibly updating it.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status) {
+ tensorflow::ServerDef server_def;
+ if (!server_def.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Invalid tensorflow.ServerDef protocol buffer");
+ return;
+ }
+ status->status = UpdateTFE_ContextWithServerDef(server_def, ctx);
+}
+
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy(
@@ -348,6 +347,11 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return -1;
+ }
int result;
status->status = h->handle->NumDims(&result);
return result;
@@ -355,12 +359,22 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return -1;
+ }
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
tensorflow::Device* d = nullptr;
status->status = h->handle->OpDevice(&d);
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
@@ -368,6 +382,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index ea019a5711..25cf7adbc7 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
-// A tensorflow.ServerDef specifies remote workers (in addition to the current
-// workers name). Operations created on this context can then be executed on
-// any of these remote workers by setting an appropriate device.
-//
-// If the following is set, all servers identified by the
-// ServerDef must be up when the context is created.
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status);
-
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
@@ -127,6 +117,17 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
unsigned char async,
TF_Status* status);
+// A tensorflow.ServerDef specifies remote workers (in addition to the current
+// workers name). Operations created on this context can then be executed on
+// any of these remote workers by setting an appropriate device.
+//
+// If the following is set, all servers identified by the
+// ServerDef must be up when the context is created.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
// Causes the calling thread to block till all ops dispatched in async mode
// have been executed. Note that "execution" here refers to kernel execution /
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
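
Note on the hunks above: TFE_ContextOptionsSetServerDef is removed and replaced by TFE_ContextSetServerDef, which attaches (or later replaces) a cluster on an already-constructed eager context. The following is a minimal sketch of the resulting call sequence, assuming the serialized tensorflow.ServerDef bytes are produced elsewhere and using only the C API entry points that appear in this patch:

    #include <stddef.h>

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/eager/c_api.h"

    // Create an eager context first, then attach the cluster described by an
    // already-serialized tensorflow.ServerDef (how those bytes are produced is
    // outside this sketch).
    void AttachCluster(const void* server_def_proto, size_t server_def_len) {
      TF_Status* status = TF_NewStatus();
      TFE_ContextOptions* opts = TFE_NewContextOptions();
      TFE_Context* ctx = TFE_NewContext(opts, status);
      TFE_DeleteContextOptions(opts);
      if (TF_GetCode(status) != TF_OK) { /* inspect TF_Message(status) */ }

      // The ServerDef is now set on the live context rather than through the
      // removed TFE_ContextOptionsSetServerDef option, and it can be replaced
      // later by another TFE_ContextSetServerDef call on the same context.
      TFE_ContextSetServerDef(ctx, server_def_proto, server_def_len, status);
      if (TF_GetCode(status) != TF_OK) { /* inspect TF_Message(status) */ }

      // ... create tensor handles and ops, targeting remote devices as needed ...

      TFE_DeleteContext(ctx);
      TF_DeleteStatus(status);
    }

Setting the ServerDef after context creation is what the new RemoteExecuteChangeServerDef test in c_api_test.cc exercises: the same context is repointed at a different cluster mid-run.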
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 4c5077023d..a5c0681e2e 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -59,7 +59,6 @@ struct TFE_ContextOptions {
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
- tensorflow::ServerDef server_def;
};
struct TFE_Context {
@@ -73,23 +72,6 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
- explicit TFE_Context(
- const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy, bool async,
- tensorflow::DeviceMgr* local_device_mgr,
- tensorflow::Rendezvous* rendezvous,
- std::unique_ptr<tensorflow::ServerInterface> server,
- std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
- std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
- const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
- remote_contexts)
- : context(opts,
- static_cast<tensorflow::ContextDevicePlacementPolicy>(
- default_policy),
- async, local_device_mgr, rendezvous, std::move(server),
- std::move(remote_eager_workers), std::move(remote_device_mgr),
- remote_contexts) {}
-
tensorflow::EagerContext context;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 0bdea70fe6..00a0a71fca 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -108,14 +108,14 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
-tensorflow::ServerDef GetServerDef(int num_tasks) {
+tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
- server_def.set_job_name("localhost");
+ server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
- job_def->set_name("localhost");
+ job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
@@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
return server_def;
}
+tensorflow::ServerDef GetServerDef(int num_tasks) {
+ return GetServerDef("localhost", num_tasks);
+}
+
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
@@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
@@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
- TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
@@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
+void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
+ const std::vector<float>& expected_values) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
+ EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
+ memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+
+ for (int i = 0; i < expected_values.size(); i++) {
+ EXPECT_EQ(expected_values[i], actual_values[i])
+ << "Mismatch in expected values at (zero-based) index " << i;
+ }
+}
+
+void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
+ const char* remote_device_name,
+ const char* local_device_name) {
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
+ TFE_OpSetDevice(matmul, remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 =
+ TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
+
+ TFE_DeleteTensorHandle(retval_task0);
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+
+void TestRemoteExecuteChangeServerDef(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(2);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+
+ std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+ const char local_device_name[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+ CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+
+ // Update the server def with a new set of names (worker instead of
+ // localhost).
+ tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
+ serialized = updated_server_def.SerializeAsString();
+
+ updated_server_def.set_task_index(1);
+ tensorflow::Status s = tensorflow::GrpcServer::Create(
+ updated_server_def, tensorflow::Env::Default(), &worker_server);
+ ASSERT_TRUE(s.ok()) << s.error_message();
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Create a new tensor_handle.
+ TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
+
+ // Check that copying it to the old remote device (named localhost) fails.
+ TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
+ EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Copying and executing on the new remote device works.
+ const char new_remote_device_name[] =
+ "/job:worker/replica:0/task:1/device:CPU:0";
+ const char new_local_device_name[] =
+ "/job:worker/replica:0/task:0/device:CPU:0";
+
+ auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
+ h0_task0_new, ctx, new_remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_DeleteTensorHandle(h0_task0_new);
+ TFE_DeleteTensorHandle(h0_task1_new);
+
+ CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
+ new_local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ TFE_DeleteContext(ctx);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteChangeServerDef) {
+ TestRemoteExecuteChangeServerDef(false);
+}
+TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
+ TestRemoteExecuteChangeServerDef(true);
+}
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
@@ -615,6 +760,42 @@ void SetAndGetOpDevices(bool async) {
TF_DeleteStatus(status);
}
+TEST(CAPI, TensorHandleNullptr) {
+ TFE_TensorHandle* h = nullptr;
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(t, nullptr);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
+ const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(device_name, nullptr);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
+ int num_dims = TFE_TensorHandleNumDims(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(num_dims, -1);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
+ int dim = TFE_TensorHandleDim(h, 0, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(dim, -1);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+}
+
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index a98f0b00b2..588a45ea43 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -121,6 +121,7 @@ cc_library(
deps = [
":array_grad",
":data_flow_grad",
+ ":image_grad",
":math_grad",
":nn_grad",
],
@@ -332,6 +333,36 @@ tf_cc_test(
)
cc_library(
+ name = "image_grad",
+ srcs = ["gradients/image_grad.cc"],
+ deps = [
+ ":cc_ops",
+ ":cc_ops_internal",
+ ":grad_op_registry",
+ ":gradients",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "gradients_image_grad_test",
+ srcs = ["gradients/image_grad_test.cc"],
+ deps = [
+ ":cc_ops",
+ ":client_session",
+ ":grad_op_registry",
+ ":grad_testutil",
+ ":gradient_checker",
+ ":image_grad",
+ ":testutil",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
name = "math_grad",
srcs = ["gradients/math_grad.cc"],
deps = [
diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc
index ba056a8f3a..0e61089a59 100644
--- a/tensorflow/cc/client/client_session.cc
+++ b/tensorflow/cc/client/client_session.cc
@@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
target_node_names, outputs, run_metadata);
}
+Status ClientSession::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
+ return impl()->session_->MakeCallable(callable_options, out_handle);
+}
+
+Status ClientSession::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status ClientSession::ReleaseCallable(CallableHandle handle) {
+ return impl()->session_->ReleaseCallable(handle);
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h
index 5fb4109f7d..7dd653eec4 100644
--- a/tensorflow/cc/client/client_session.h
+++ b/tensorflow/cc/client/client_session.h
@@ -87,7 +87,33 @@ class ClientSession {
const std::vector<Operation>& run_outputs,
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
- // TODO(keveman): Add support for partial run.
+ /// \brief A handle to a subgraph, created with
+ /// `ClientSession::MakeCallable()`.
+ typedef int64 CallableHandle;
+
+ /// \brief Creates a `handle` for invoking the subgraph defined by
+ /// `callable_options`.
+ /// NOTE: This API is still experimental and may change.
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle);
+
+ /// \brief Invokes the subgraph named by `handle` with the given options and
+ /// input tensors.
+ ///
+ /// The order of tensors in `feed_tensors` must match the order of names in
+ /// `CallableOptions::feed()`, and the order of tensors in `fetch_tensors`
+ /// will match the order of names in `CallableOptions::fetch()` from when
+ /// this subgraph was created.
+ /// NOTE: This API is still experimental and may change.
+ Status RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata);
+
+ /// \brief Releases resources associated with the given `handle` in this
+ /// session.
+ /// NOTE: This API is still experimental and may change.
+ Status ReleaseCallable(CallableHandle handle);
private:
class Impl;
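
The MakeCallable/RunCallable/ReleaseCallable trio added above mirrors the callable API that Session already exposes: the feed/fetch signature is fixed once up front, and the pruned subgraph can then be run repeatedly without re-specifying it. The new test below shows a single round trip; a sketch of the reuse pattern the API targets (the graph outputs `a` and `c` and the `session` are assumed to exist, as in that test):

    CallableOptions opts;
    opts.add_feed(a.node()->name());
    opts.add_fetch(c.node()->name());
    ClientSession::CallableHandle handle;
    TF_CHECK_OK(session.MakeCallable(opts, &handle));
    for (int step = 0; step < 100; ++step) {
      std::vector<Tensor> outputs;
      // Feed order must match CallableOptions::feed(), as documented above.
      TF_CHECK_OK(session.RunCallable(handle, {test::AsTensor<int>({step}, {})},
                                      &outputs, nullptr));
    }
    TF_CHECK_OK(session.ReleaseCallable(handle));
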
diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc
index ea5cf5a1f1..559ffea7e8 100644
--- a/tensorflow/cc/client/client_session_test.cc
+++ b/tensorflow/cc/client/client_session_test.cc
@@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) {
test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({-1, 2}, {2}));
}
+TEST(ClientSessionTest, Callable) {
+ Scope root = Scope::NewRootScope();
+ auto a = Placeholder(root, DT_INT32);
+ auto b = Placeholder(root, DT_INT32);
+ auto c = Add(root, a, b);
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+
+ CallableOptions options;
+ options.add_feed(a.node()->name());
+ options.add_feed(b.node()->name());
+ options.add_fetch(c.node()->name());
+ ClientSession::CallableHandle callable;
+ TF_CHECK_OK(session.MakeCallable(options, &callable));
+ TF_EXPECT_OK(session.RunCallable(
+ callable, {test::AsTensor<int>({1}, {}), test::AsTensor<int>({41}, {})},
+ &outputs, nullptr));
+ test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({42}, {}));
+ TF_EXPECT_OK(session.ReleaseCallable(callable));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc
index de2645cb44..e9f9c59e3a 100644
--- a/tensorflow/cc/framework/gradient_checker.cc
+++ b/tensorflow/cc/framework/gradient_checker.cc
@@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs,
auto y_pos_flat = y_pos[y_idx].flat<Y_T>();
auto y_neg_flat = y_neg[y_idx].flat<Y_T>();
const int64 y_size = y_shapes[y_idx].num_elements();
- const Y_T scale = Y_T{2 * delta};
+ const Y_T scale = 2 * delta;
auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix<JAC_T>();
for (int c = 0; c < y_size; ++c) {
SetJacobian<Y_T, JAC_T>(&jacobian, r * x_stride + unit_dimension,
@@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs,
auto jac_n = jacobian_ns[i].matrix<JAC_T>();
for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) {
for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) {
- *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c)));
+ auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c));
+ // Treat any NaN as max_error and immediately return.
+ // (Note that std::max may ignore NaN arguments.)
+ if (std::isnan(cur_error)) {
+ *max_error = cur_error;
+ return Status::OK();
+ }
+ *max_error = std::max(*max_error, cur_error);
}
}
}
@@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x,
const Output& y, const TensorShape& y_shape, JAC_T* max_error);
INSTANTIATE_GRAD_ERR_TYPE(float, float, float);
+INSTANTIATE_GRAD_ERR_TYPE(double, float, double);
INSTANTIATE_GRAD_ERR_TYPE(double, double, double);
INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float);
INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float);
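
The early return on NaN above is needed because std::max(a, b) is defined as (a < b) ? b : a: every comparison against NaN is false, so std::max(*max_error, NaN) quietly keeps *max_error and an undefined gradient would be reported as a zero error. A self-contained illustration of the behavior being guarded against:

    #include <algorithm>
    #include <cmath>
    #include <iostream>
    #include <limits>

    int main() {
      double max_error = 0.0;
      const double cur_error = std::numeric_limits<double>::quiet_NaN();
      // 0.0 < NaN is false, so std::max keeps the first argument and drops the NaN.
      max_error = std::max(max_error, cur_error);
      std::cout << max_error << "\n";              // prints 0, not nan
      std::cout << std::isnan(cur_error) << "\n";  // prints 1: the explicit check catches it
      return 0;
    }
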
diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc
index d4f0a7f5ab..8dd762c282 100644
--- a/tensorflow/cc/framework/gradient_checker_test.cc
+++ b/tensorflow/cc/framework/gradient_checker_test.cc
@@ -28,12 +28,14 @@ namespace {
using ops::Complex;
using ops::Const;
+using ops::Div;
using ops::MatMul;
using ops::Placeholder;
using ops::Real;
using ops::Split;
using ops::Square;
using ops::Stack;
+using ops::Sub;
using ops::Unstack;
TEST(GradientCheckerTest, BasicFloat) {
@@ -104,6 +106,20 @@ TEST(GradientCheckerTest, Complex64ToFloat) {
EXPECT_LT(max_error, 1e-4);
}
+// When calculating gradients that are undefined, test that we get NaN
+// as the computed error rather than 0.
+TEST(GradientCheckerTest, BasicNan) {
+ Scope scope = Scope::NewRootScope();
+ TensorShape shape({2, 4, 3});
+ auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape));
+ // y = x/(x-x) should always return NaN
+ auto y = Div(scope, x, Sub(scope, x, x));
+ float max_error;
+ TF_ASSERT_OK((ComputeGradientError<float, float, float>(
+ scope, {x}, {shape}, {y}, {shape}, &max_error)));
+ EXPECT_TRUE(std::isnan(max_error));
+}
+
TEST(GradientCheckerTest, MatMulGrad) {
Scope scope = Scope::NewRootScope();
diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc
new file mode 100644
index 0000000000..882709e1e2
--- /dev/null
+++ b/tensorflow/cc/gradients/image_grad.cc
@@ -0,0 +1,74 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/ops/image_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace ops {
+namespace {
+
+Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ bool align_corners;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners));
+ // The internal gradient implementation needs the shape of the input image.
+ // x_shape = shape(x)[1:3]
+ // = slice(shape(x), {1}, {3 - 1})
+ auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2});
+ grad_outputs->push_back(internal::ResizeNearestNeighborGrad(
+ scope, grad_inputs[0], x_shape,
+ internal::ResizeNearestNeighborGrad::AlignCorners(align_corners)));
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("ResizeNearestNeighbor", ResizeNearestNeighborGradHelper);
+
+Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ bool align_corners;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners));
+ grad_outputs->push_back(internal::ResizeBilinearGrad(
+ scope, grad_inputs[0], op.input(0),
+ internal::ResizeBilinearGrad::AlignCorners(align_corners)));
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("ResizeBilinear", ResizeBilinearGradHelper);
+
+Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ bool align_corners;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners));
+ grad_outputs->push_back(internal::ResizeBicubicGrad(
+ scope, grad_inputs[0], op.input(0),
+ internal::ResizeBicubicGrad::AlignCorners(align_corners)));
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper);
+
+} // anonymous namespace
+} // namespace ops
+} // namespace tensorflow
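
With these three helpers registered, the C++ gradient registry can differentiate through the resize ops like any other op with a registered gradient. A sketch of exercising ResizeBilinearGradHelper indirectly through AddSymbolicGradients (shapes and names are illustrative; assumes the includes already used in the test file below):

    Scope scope = Scope::NewRootScope();
    auto x = ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({1, 2, 3, 1}));
    auto y = ops::ResizeBilinear(scope, x, {4, 6});
    std::vector<Output> grad_outputs;
    // Looks up "ResizeBilinear" in the registry populated by REGISTER_GRADIENT_OP.
    TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grad_outputs));
    // grad_outputs[0] has the shape of x; the size input gets NoGradient().
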
diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc
new file mode 100644
index 0000000000..2e55c7561b
--- /dev/null
+++ b/tensorflow/cc/gradients/image_grad_test.cc
@@ -0,0 +1,157 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/grad_op_registry.h"
+#include "tensorflow/cc/framework/gradient_checker.h"
+#include "tensorflow/cc/framework/testutil.h"
+#include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+using ops::Const;
+using ops::ResizeBicubic;
+using ops::ResizeBilinear;
+using ops::ResizeNearestNeighbor;
+
+class ImageGradTest : public ::testing::Test {
+ protected:
+ ImageGradTest() : scope_(Scope::NewRootScope()) {}
+
+ enum OpType { RESIZE_NEAREST, RESIZE_BILINEAR, RESIZE_BICUBIC };
+
+ template <typename T>
+ Tensor MakeData(const TensorShape& data_shape) {
+ DataType data_type = DataTypeToEnum<T>::v();
+ Tensor data(data_type, data_shape);
+ auto data_flat = data.flat<T>();
+ for (int i = 0; i < data_flat.size(); ++i) {
+ data_flat(i) = T(i);
+ }
+ return data;
+ }
+
+ template <typename T>
+ void MakeOp(const OpType op_type, const Tensor& x_data, const Input& y_shape,
+ const bool align_corners, Output* x, Output* y) {
+ *x = Const<T>(scope_, x_data);
+ switch (op_type) {
+ case RESIZE_NEAREST:
+ *y = ResizeNearestNeighbor(
+ scope_, *x, y_shape,
+ ResizeNearestNeighbor::AlignCorners(align_corners));
+ return;
+ case RESIZE_BILINEAR:
+ *y = ResizeBilinear(scope_, *x, y_shape,
+ ResizeBilinear::AlignCorners(align_corners));
+ return;
+ case RESIZE_BICUBIC:
+ *y = ResizeBicubic(scope_, *x, y_shape,
+ ResizeBicubic::AlignCorners(align_corners));
+ return;
+ }
+ assert(false);
+ }
+
+ template <typename T>
+ void TestResizedShapeForType(const OpType op_type, const bool align_corners) {
+ TensorShape x_shape({1, 2, 2, 1});
+ Tensor x_data = MakeData<T>(x_shape);
+ Output x, y;
+ MakeOp<T>(op_type, x_data, {4, 6}, align_corners, &x, &y);
+
+ ClientSession session(scope_);
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session.Run({y}, &outputs));
+ EXPECT_EQ(outputs.size(), 1);
+ EXPECT_EQ(outputs[0].shape(), TensorShape({1, 4, 6, 1}));
+ }
+
+ void TestResizedShape(OpType op_type) {
+ for (const bool align_corners : {true, false}) {
+ TestResizedShapeForType<Eigen::half>(op_type, align_corners);
+ TestResizedShapeForType<float>(op_type, align_corners);
+ TestResizedShapeForType<double>(op_type, align_corners);
+ }
+ }
+
+ template <typename X_T, typename Y_T, typename JAC_T>
+ void TestResizeToSmallerAndAlign(const OpType op_type,
+ const bool align_corners) {
+ TensorShape x_shape({1, 4, 6, 1});
+ Tensor x_data = MakeData<X_T>(x_shape);
+ Output x, y;
+ MakeOp<X_T>(op_type, x_data, {2, 3}, align_corners, &x, &y);
+ JAC_T max_error;
+ TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
+ scope_, x, x_data, y, {1, 2, 3, 1}, &max_error)));
+ EXPECT_LT(max_error, 1e-3);
+ }
+
+ template <typename X_T, typename Y_T, typename JAC_T>
+ void TestResizeToLargerAndAlign(const OpType op_type,
+ const bool align_corners) {
+ TensorShape x_shape({1, 2, 3, 1});
+ Tensor x_data = MakeData<X_T>(x_shape);
+ Output x, y;
+ MakeOp<X_T>(op_type, x_data, {4, 6}, align_corners, &x, &y);
+ JAC_T max_error;
+ TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, JAC_T>(
+ scope_, x, x_data, y, {1, 4, 6, 1}, &max_error)));
+ EXPECT_LT(max_error, 1e-3);
+ }
+
+ template <typename X_T, typename Y_T, typename JAC_T>
+ void TestResize(OpType op_type) {
+ for (const bool align_corners : {true, false}) {
+ TestResizeToSmallerAndAlign<X_T, Y_T, JAC_T>(op_type, align_corners);
+ TestResizeToLargerAndAlign<X_T, Y_T, JAC_T>(op_type, align_corners);
+ }
+ }
+
+ Scope scope_;
+};
+
+TEST_F(ImageGradTest, TestNearestNeighbor) {
+ TestResizedShape(RESIZE_NEAREST);
+ TestResize<float, float, float>(RESIZE_NEAREST);
+ TestResize<double, double, double>(RESIZE_NEAREST);
+}
+
+TEST_F(ImageGradTest, TestBilinear) {
+ TestResizedShape(RESIZE_BILINEAR);
+ TestResize<float, float, float>(RESIZE_BILINEAR);
+ // Note that Y_T is always float for this op. We choose
+ // double for the jacobian to capture the higher precision
+ // between X_T and Y_T.
+ TestResize<double, float, double>(RESIZE_BILINEAR);
+}
+
+TEST_F(ImageGradTest, TestBicubic) {
+ TestResizedShape(RESIZE_BICUBIC);
+ TestResize<float, float, float>(RESIZE_BICUBIC);
+ // Note that Y_T is always float for this op. We choose
+ // double for the jacobian to capture the higher precision
+ // between X_T and Y_T.
+ TestResize<double, float, double>(RESIZE_BICUBIC);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index fef8b8d4d4..d2f803bd18 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -8,28 +8,6 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
-# Optional runtime utilities for use by code generated by tfcompile.
-cc_library(
- name = "runtime",
- srcs = ["runtime.cc"],
- hdrs = ["runtime.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework_lite",
- ],
-)
-
-tf_cc_test(
- name = "runtime_test",
- srcs = ["runtime_test.cc"],
- deps = [
- ":runtime",
- "//tensorflow/core:framework",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
# Don't depend on this directly; this is only used for the benchmark test
# generated by tf_library.
cc_library(
@@ -53,9 +31,9 @@ cc_library(
],
deps = [
":embedded_protocol_buffers",
- ":runtime", # needed by codegen to print aligned_buffer_bytes
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/tf2xla:tf2xla_proto",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
@@ -238,7 +216,6 @@ test_suite(
tests = [
":benchmark_test",
":codegen_test",
- ":runtime_test",
":test_graph_tfadd_test",
":test_graph_tfunknownop2_test",
":test_graph_tfunknownop3_test",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 28070d60db..8dbe1e11b7 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
-#include "tensorflow/compiler/aot/runtime.h"
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@@ -303,10 +303,10 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
const size_t arg_bytes_aligned =
- runtime::aligned_buffer_bytes(iarg.data(), iarg.size());
+ cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size());
const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
const size_t temp_bytes_aligned =
- runtime::aligned_buffer_bytes(itemp.data(), itemp.size());
+ cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size());
const size_t temp_bytes_total =
total_buffer_bytes(itemp.data(), itemp.size());
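
For context, AlignedBufferBytes (formerly runtime::aligned_buffer_bytes) reports how many bytes the generated function's argument and temp buffers occupy once each buffer is rounded up to the runtime's alignment boundary, which is what the generated header comments print. A hedged sketch of that arithmetic; the alignment value and the handling of sentinel sizes are assumptions here, the authoritative version lives in cpu_function_runtime.h/.cc:

    #include <cstddef>
    #include <cstdint>

    constexpr size_t kAlign = 64;  // assumed; see cpu_function_runtime.h for the real constant

    size_t AlignTo(size_t n, size_t align) { return (n + align - 1) / align * align; }

    size_t AlignedBufferBytesSketch(const intptr_t* sizes, size_t n) {
      size_t total = 0;
      for (size_t i = 0; i < n; ++i) {
        // Non-positive sizes mark buffers the caller supplies at run time;
        // this sketch assumes they contribute nothing to the allocation.
        if (sizes[i] > 0) total += AlignTo(static_cast<size_t>(sizes[i]), kAlign);
      }
      return total;
    }
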
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 5c57fee326..326f73b975 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -16,339 +16,365 @@ tf_library(
)
"""
-load("//tensorflow:tensorflow.bzl",
- "if_android", "tf_cc_test", "tf_copts")
-
-def tf_library(name, graph, config,
- freeze_checkpoint=None, freeze_saver=None,
- cpp_class=None, gen_test=True, gen_benchmark=True,
- visibility=None, testonly=None,
- tfcompile_flags=None,
- tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
- include_standard_runtime_deps=True,
- enable_xla_hlo_profiling=False, deps=None, tags=None):
- """Runs tfcompile to compile a TensorFlow graph into executable code.
-
- Given an invocation of tf_library(name="foo", ...), generates the following
- build targets:
- foo: A cc_library containing the generated header and computation.
- foo_test: A cc_test with simple tests and benchmarks. Only created if
- gen_test=True.
- foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful
- for mobile devices or other platforms that can't compile the
- full test libraries. Only created if gen_benchmark=True.
-
- Args:
- name: The name of the build rule.
- graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
- is expected to be in the human-readable proto text format, otherwise it is
- expected to be in the proto binary format.
- config: File containing tensorflow.tf2xla.Config proto. If the file ends
- in '.pbtxt' it is expected to be in the human-readable proto text format,
- otherwise it is expected to be in the proto binary format.
- freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
- convert variables into constants.
- freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
- binary form, to convert variables into constants.
- cpp_class: The name of the generated C++ class, wrapping the generated
- function. The syntax of this flag is
- [[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
- for referring to a class, where multiple namespaces may precede the class
- name, separated by double-colons. The class will be generated in the
- given namespace(s), or if no namespaces are given, within the global
- namespace.
- gen_test: If True, also generate a cc_test rule that builds a simple
- test and benchmark.
- gen_benchmark: If True, also generate a binary with a simple benchmark.
- Unlike the output of gen_test, this benchmark can be run on android.
- visibility: Bazel build visibility.
- testonly: Bazel testonly attribute.
- tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
- tfcompile_tool: The tfcompile binary. A non-default can be passed to
- use a tfcompile built with extra dependencies.
- include_standard_runtime_deps: If True, the standard list of kernel/runtime
- deps is added to deps. If False, deps must contain the full set of deps
- needed by the generated library.
- enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
- and emit metadata that lets us pretty-print the gathered profile counters.
- deps: a list of deps to include on the build rules for the generated
- library, added to the standard deps if standard_runtime_deps is True.
- tags: tags to apply to subsidiary build rules.
-
- The output header is called <name>.h.
- """
- if not cpp_class:
- fail("cpp_class must be specified")
-
- tfcompile_graph = graph
- if freeze_checkpoint or freeze_saver:
- if not freeze_checkpoint:
- fail("freeze_checkpoint must be specified when freeze_saver is specified")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "if_android",
+ "tf_cc_test",
+ "tf_copts",
+)
- freeze_name = "freeze_" + name
- freeze_file = freeze_name + ".pb"
+def tf_library(
+ name,
+ graph,
+ config,
+ freeze_checkpoint = None,
+ freeze_saver = None,
+ cpp_class = None,
+ gen_test = True,
+ gen_benchmark = True,
+ visibility = None,
+ testonly = None,
+ tfcompile_flags = None,
+ tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
+ include_standard_runtime_deps = True,
+ enable_xla_hlo_profiling = False,
+ deps = None,
+ tags = None):
+ """Runs tfcompile to compile a TensorFlow graph into executable code.
- # First run tfcompile to generate the list of out_nodes.
- out_nodes_file = "out_nodes_" + freeze_name
- native.genrule(
- name=("gen_" + out_nodes_file),
- srcs=[config],
- outs=[out_nodes_file],
- cmd=("$(location " + tfcompile_tool + ")" +
- " --config=$(location " + config + ")" +
- " --dump_fetch_nodes > $@"),
- tools=[tfcompile_tool],
- # Run tfcompile on the build host, rather than forge, since it's
- # typically way faster on the local machine.
- local=1,
- tags=tags,
- )
+ Given an invocation of tf_library(name="foo", ...), generates the following
+ build targets:
+ foo: A cc_library containing the generated header and
+ computation.
+ foo_test: A cc_test with simple tests and benchmarks. Only created if
+ gen_test=True.
+ foo_benchmark: A cc_binary that runs a minimal-dependency benchmark,
+ useful for mobile devices or other platforms that can't
+ compile the full test libraries. Only created if
+ gen_benchmark=True.
+ The output header is called <name>.h.
- # Now run freeze_graph to convert variables into constants.
- freeze_args = (" --input_graph=$(location " + graph + ")" +
- " --checkpoint_version=1" +
- " --input_binary=" + str(not graph.endswith(".pbtxt")) +
- " --input_checkpoint=$(location " + freeze_checkpoint + ")" +
- " --output_graph=$(location " + freeze_file + ")" +
- " --output_node_names=$$(<$(location " + out_nodes_file +
- "))")
- freeze_saver_srcs = []
- if freeze_saver:
- freeze_args += " --input_saver=$(location " + freeze_saver + ")"
- freeze_saver_srcs += [freeze_saver]
- native.genrule(
- name=freeze_name,
- srcs=[
- graph,
- freeze_checkpoint,
- out_nodes_file,
- ] + freeze_saver_srcs,
- outs=[freeze_file],
- cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
- freeze_args),
- tools=["//tensorflow/python/tools:freeze_graph"],
- tags=tags,
- )
- tfcompile_graph = freeze_file
+ Args:
+ name: The name of the build rule.
+ graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt'
+ it is expected to be in the human-readable proto text format, otherwise
+ it is expected to be in the proto binary format.
+ config: File containing tensorflow.tf2xla.Config proto. If the file ends
+ in '.pbtxt' it is expected to be in the human-readable proto text
+ format, otherwise it is expected to be in the proto binary format.
+ freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
+ convert variables into constants.
+ freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
+ binary form, to convert variables into constants.
+ cpp_class: The name of the generated C++ class, wrapping the generated
+ function. The syntax of this flag is
+ [[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
+ for referring to a class, where multiple namespaces may precede the
+ class name, separated by double-colons. The class will be generated in
+ the given namespace(s), or if no namespaces are given, within the global
+ namespace.
+ gen_test: If True, also generate a cc_test rule that builds a simple
+ test and benchmark.
+ gen_benchmark: If True, also generate a binary with a simple benchmark.
+ Unlike the output of gen_test, this benchmark can be run on android.
+ visibility: Bazel build visibility.
+ testonly: Bazel testonly attribute.
+ tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
+ tfcompile_tool: The tfcompile binary. A non-default can be passed to
+ use a tfcompile built with extra dependencies.
+ include_standard_runtime_deps: If True, the standard list of
+ kernel/runtime deps is added to deps. If False, deps must contain the
+ full set of deps needed by the generated library.
+ enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
+ program, and emit metadata that lets us pretty-print the gathered
+ profile counters.
+ deps: a list of deps to include on the build rules for the generated
+ library, added to the standard deps if include_standard_runtime_deps is True.
+ tags: tags to apply to subsidiary build rules.
+ """
+ if not cpp_class:
+ fail("cpp_class must be specified")
- # Rule that runs tfcompile to produce the header and object file.
- header_file = name + ".h"
- metadata_object_file = name + "_tfcompile_metadata.o"
- function_object_file = name + "_tfcompile_function.o"
- ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
- if type(tfcompile_flags) == type(""):
- flags = tfcompile_flags
- else:
- flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
- if enable_xla_hlo_profiling:
- profiling_flag = "--xla_hlo_profile"
- else:
- profiling_flag = ""
- native.genrule(
- name=("gen_" + name),
- srcs=[
- tfcompile_graph,
- config,
- ],
- outs=[
- header_file,
- metadata_object_file,
- function_object_file,
- ],
- cmd=("$(location " + tfcompile_tool + ")" +
- " --graph=$(location " + tfcompile_graph + ")" +
- " --config=$(location " + config + ")" +
- " --entry_point=" + ep +
- " --cpp_class=" + cpp_class +
- " --target_triple=" + target_llvm_triple() +
- " --out_header=$(@D)/" + header_file +
- " --out_metadata_object=$(@D)/" + metadata_object_file +
- " --out_function_object=$(@D)/" + function_object_file +
- " " + flags + " " + profiling_flag),
- tools=[tfcompile_tool],
- visibility=visibility,
- testonly=testonly,
- # Run tfcompile on the build host since it's typically faster on the local
- # machine.
- #
- # Note that setting the local=1 attribute on a *test target* causes the
- # test infrastructure to skip that test. However this is a genrule, not a
- # test target, and runs with --genrule_strategy=forced_forge, meaning the
- # local=1 attribute is ignored, and the genrule is still run.
- #
- # https://www.bazel.io/versions/master/docs/be/general.html#genrule
- local=1,
- tags=tags,
- )
+ tfcompile_graph = graph
+ if freeze_checkpoint or freeze_saver:
+ if not freeze_checkpoint:
+ fail("freeze_checkpoint must be specified when freeze_saver is " +
+ "specified")
- # Rule that runs tfcompile to produce the SessionModule proto, useful for
- # debugging. TODO(b/64813587): Once the SessionModule proto is
- # deterministic, move this into the main rule above.
- session_module_pb = name + "_session_module.pb"
- native.genrule(
- name=(name + "_session_module"),
- srcs=[
- tfcompile_graph,
- config,
- ],
- outs=[
- session_module_pb,
- ],
- cmd=("$(location " + tfcompile_tool + ")" +
- " --graph=$(location " + tfcompile_graph + ")" +
- " --config=$(location " + config + ")" +
- " --entry_point=" + ep +
- " --cpp_class=" + cpp_class +
- " --target_triple=" + target_llvm_triple() +
- " --out_session_module=$(@D)/" + session_module_pb +
- " " + flags),
- tools=[tfcompile_tool],
- visibility=visibility,
- testonly=testonly,
- local=1,
- tags=tags,
- )
+ freeze_name = "freeze_" + name
+ freeze_file = freeze_name + ".pb"
- # The cc_library rule packaging up the header and object file, and needed
- # kernel implementations.
- need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
- native.cc_library(
- name=name,
- srcs=[function_object_file, metadata_object_file],
- hdrs=[header_file],
- visibility=visibility,
- testonly=testonly,
- deps = [
- # These deps are required by all tf_library targets even if
- # include_standard_runtime_deps is False. Without them, the
- # generated code will fail to compile.
- "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
- "//tensorflow/core:framework_lite",
- ] + (need_xla_data_proto and [
- # If we're generating the program shape, we must depend on the proto.
- "//tensorflow/compiler/xla:xla_data_proto",
- ] or []) + (enable_xla_hlo_profiling and [
- "//tensorflow/compiler/xla/service:hlo_profile_printer_data"
- ] or []) + (include_standard_runtime_deps and [
- # TODO(cwhipkey): only depend on kernel code that the model actually needed.
- "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
- "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
- "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
- "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
- "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
- "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
- "//third_party/eigen3",
- ] or []) + (deps or []),
- tags=tags,
- )
+ # First run tfcompile to generate the list of out_nodes.
+ out_nodes_file = "out_nodes_" + freeze_name
+ native.genrule(
+ name = ("gen_" + out_nodes_file),
+ srcs = [config],
+ outs = [out_nodes_file],
+ cmd = ("$(location " + tfcompile_tool + ")" +
+ " --config=$(location " + config + ")" +
+ " --dump_fetch_nodes > $@"),
+ tools = [tfcompile_tool],
+ # Run tfcompile on the build host, rather than forge, since it's
+ # typically way faster on the local machine.
+ local = 1,
+ tags = tags,
+ )
- # Variables used for gen_test and gen_benchmark.
- no_ns_name = ""
- cpp_class_split = cpp_class.rsplit("::", maxsplit=2)
- if len(cpp_class_split) == 1:
- no_ns_name = cpp_class_split[0]
- else:
- no_ns_name = cpp_class_split[1]
- sed_replace = (
- "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
- "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
- "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ")
+ # Now run freeze_graph to convert variables into constants.
+ freeze_args = (
+ " --input_graph=$(location " + graph + ")" +
+ " --checkpoint_version=1" +
+ " --input_binary=" + str(not graph.endswith(".pbtxt")) +
+ " --input_checkpoint=$(location " + freeze_checkpoint + ")" +
+ " --output_graph=$(location " + freeze_file + ")" +
+ " --output_node_names=$$(<$(location " + out_nodes_file +
+ "))"
+ )
+ freeze_saver_srcs = []
+ if freeze_saver:
+ freeze_args += " --input_saver=$(location " + freeze_saver + ")"
+ freeze_saver_srcs += [freeze_saver]
+ native.genrule(
+ name = freeze_name,
+ srcs = [
+ graph,
+ freeze_checkpoint,
+ out_nodes_file,
+ ] + freeze_saver_srcs,
+ outs = [freeze_file],
+ cmd = ("$(location " +
+ "//tensorflow/python/tools:freeze_graph)" +
+ freeze_args),
+ tools = ["//tensorflow/python/tools:freeze_graph"],
+ tags = tags,
+ )
+ tfcompile_graph = freeze_file
- if gen_test:
- test_name = name + "_test"
- test_file = test_name + ".cc"
- # Rule to rewrite test.cc to produce the test_file.
+ # Rule that runs tfcompile to produce the header and object file.
+ header_file = name + ".h"
+ metadata_object_file = name + "_tfcompile_metadata.o"
+ function_object_file = name + "_tfcompile_function.o"
+ ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
+ if type(tfcompile_flags) == type(""):
+ flags = tfcompile_flags
+ else:
+ flags = " ".join([
+ "'" + arg.replace("'", "'\\''") + "'"
+ for arg in (tfcompile_flags or [])
+ ])
+ if enable_xla_hlo_profiling:
+ profiling_flag = "--xla_hlo_profile"
+ else:
+ profiling_flag = ""
native.genrule(
- name=("gen_" + test_name),
- testonly=1,
- srcs=[
- "//tensorflow/compiler/aot:test.cc",
+ name = ("gen_" + name),
+ srcs = [
+ tfcompile_graph,
+ config,
+ ],
+ outs = [
header_file,
+ metadata_object_file,
+ function_object_file,
],
- outs=[test_file],
- cmd=("sed " + sed_replace +
- " $(location //tensorflow/compiler/aot:test.cc) " +
- "> $(OUTS)"),
- tags=tags,
- )
-
- # The cc_test rule for the generated code. To ensure that this works
- # reliably across build configurations, we must use tf_cc_test instead of
- # native.cc_test. This is related to how we build
- # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
- # for more details.
- tf_cc_test(
- name=test_name,
- srcs=[test_file],
- deps=[
- ":" + name,
- "//tensorflow/compiler/aot:runtime",
- "//tensorflow/compiler/aot:tf_library_test_main",
- "//tensorflow/compiler/xla:executable_run_options",
- "//third_party/eigen3",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- ],
- tags=tags,
+ cmd = ("$(location " + tfcompile_tool + ")" +
+ " --graph=$(location " + tfcompile_graph + ")" +
+ " --config=$(location " + config + ")" +
+ " --entry_point=" + ep +
+ " --cpp_class=" + cpp_class +
+ " --target_triple=" + target_llvm_triple() +
+ " --out_header=$(@D)/" + header_file +
+ " --out_metadata_object=$(@D)/" + metadata_object_file +
+ " --out_function_object=$(@D)/" + function_object_file +
+ " " + flags + " " + profiling_flag),
+ tools = [tfcompile_tool],
+ visibility = visibility,
+ testonly = testonly,
+ # Run tfcompile on the build host since it's typically faster on the
+ # local machine.
+ #
+ # Note that setting the local=1 attribute on a *test target* causes the
+ # test infrastructure to skip that test. However this is a genrule, not
+ # a test target, and runs with --genrule_strategy=forced_forge, meaning
+ # the local=1 attribute is ignored, and the genrule is still run.
+ #
+ # https://www.bazel.io/versions/master/docs/be/general.html#genrule
+ local = 1,
+ tags = tags,
)
- if gen_benchmark:
- benchmark_name = name + "_benchmark"
- benchmark_file = benchmark_name + ".cc"
- benchmark_main = ("//tensorflow/compiler/aot:" +
- "benchmark_main.template")
-
- # Rule to rewrite benchmark.cc to produce the benchmark_file.
+ # Rule that runs tfcompile to produce the SessionModule proto, useful for
+ # debugging. TODO(b/64813587): Once the SessionModule proto is
+ # deterministic, move this into the main rule above.
+ session_module_pb = name + "_session_module.pb"
native.genrule(
- name=("gen_" + benchmark_name),
- srcs=[
- benchmark_main,
- header_file,
+ name = (name + "_session_module"),
+ srcs = [
+ tfcompile_graph,
+ config,
],
+ outs = [
+ session_module_pb,
+ ],
+ cmd = ("$(location " + tfcompile_tool + ")" +
+ " --graph=$(location " + tfcompile_graph + ")" +
+ " --config=$(location " + config + ")" +
+ " --entry_point=" + ep +
+ " --cpp_class=" + cpp_class +
+ " --target_triple=" + target_llvm_triple() +
+ " --out_session_module=$(@D)/" + session_module_pb +
+ " " + flags),
+ tools = [tfcompile_tool],
+ visibility = visibility,
testonly = testonly,
- outs=[benchmark_file],
- cmd=("sed " + sed_replace +
- " $(location " + benchmark_main + ") " +
- "> $(OUTS)"),
- tags=tags,
+ local = 1,
+ tags = tags,
)
- # The cc_benchmark rule for the generated code. This does not need the
- # tf_cc_binary since we (by deliberate design) do not depend on
- # //tensorflow/core:lib.
- #
- # Note: to get smaller size on android for comparison, compile with:
- # --copt=-fvisibility=hidden
- # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
- # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
- native.cc_binary(
- name=benchmark_name,
- srcs=[benchmark_file],
+ # The cc_library rule packaging up the header and object file, and needed
+ # kernel implementations.
+ need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
+ native.cc_library(
+ name = name,
+ srcs = [function_object_file, metadata_object_file],
+ hdrs = [header_file],
+ visibility = visibility,
testonly = testonly,
- copts = tf_copts(),
- linkopts = if_android(["-pie", "-s"]),
- deps=[
- ":" + name,
- "//tensorflow/compiler/aot:benchmark",
- "//tensorflow/compiler/aot:runtime",
- "//tensorflow/compiler/xla:executable_run_options",
+ deps = [
+ # These deps are required by all tf_library targets even if
+ # include_standard_runtime_deps is False. Without them, the
+ # generated code will fail to compile.
+ "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+ "//tensorflow/core:framework_lite",
+ ] + (need_xla_data_proto and [
+ # If we're generating the program shape, we must depend on the
+ # proto.
+ "//tensorflow/compiler/xla:xla_data_proto",
+ ] or []) + (enable_xla_hlo_profiling and [
+ "//tensorflow/compiler/xla/service:hlo_profile_printer_data",
+ ] or []) + (include_standard_runtime_deps and [
+ # TODO(cwhipkey): only depend on kernel code that the model actually
+ # needed.
+ "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
+ "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
+ "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//third_party/eigen3",
- ] + if_android([
- "//tensorflow/compiler/aot:benchmark_extra_android",
- ]),
- tags=tags,
+ ] or []) + (deps or []),
+ tags = tags,
+ )
+
+ # Variables used for gen_test and gen_benchmark.
+ cpp_class_split = cpp_class.rsplit("::", maxsplit = 2)
+ if len(cpp_class_split) == 1:
+ no_ns_name = cpp_class_split[0]
+ else:
+ no_ns_name = cpp_class_split[1]
+ sed_replace = (
+ "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
+ "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
+ "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" "
)
+ if gen_test:
+ test_name = name + "_test"
+ test_file = test_name + ".cc"
+
+ # Rule to rewrite test.cc to produce the test_file.
+ native.genrule(
+ name = ("gen_" + test_name),
+ testonly = 1,
+ srcs = [
+ "//tensorflow/compiler/aot:test.cc",
+ header_file,
+ ],
+ outs = [test_file],
+ cmd = (
+ "sed " + sed_replace +
+ " $(location //tensorflow/compiler/aot:test.cc) " +
+ "> $(OUTS)"
+ ),
+ tags = tags,
+ )
+
+ # The cc_test rule for the generated code. To ensure that this works
+ # reliably across build configurations, we must use tf_cc_test instead
+ # of native.cc_test. This is related to how we build
+ # //tensorflow/core:lib -- see the note in
+ # tensorflow/core/BUILD for more details.
+ tf_cc_test(
+ name = test_name,
+ srcs = [test_file],
+ deps = [
+ ":" + name,
+ "//tensorflow/compiler/aot:tf_library_test_main",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//third_party/eigen3",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+ tags = tags,
+ )
+
+ if gen_benchmark:
+ benchmark_name = name + "_benchmark"
+ benchmark_file = benchmark_name + ".cc"
+ benchmark_main = ("//tensorflow/compiler/aot:" +
+ "benchmark_main.template")
+
+ # Rule to rewrite benchmark.cc to produce the benchmark_file.
+ native.genrule(
+ name = ("gen_" + benchmark_name),
+ srcs = [
+ benchmark_main,
+ header_file,
+ ],
+ testonly = testonly,
+ outs = [benchmark_file],
+ cmd = ("sed " + sed_replace +
+ " $(location " + benchmark_main + ") " +
+ "> $(OUTS)"),
+ tags = tags,
+ )
+
+ # The cc_benchmark rule for the generated code. This does not need the
+ # tf_cc_binary since we (by deliberate design) do not depend on
+ # //tensorflow/core:lib.
+ #
+ # Note: to get smaller size on android for comparison, compile with:
+ # --copt=-fvisibility=hidden
+ # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
+ # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
+ native.cc_binary(
+ name = benchmark_name,
+ srcs = [benchmark_file],
+ testonly = testonly,
+ copts = tf_copts(),
+ linkopts = if_android(["-pie", "-s"]),
+ deps = [
+ ":" + name,
+ "//tensorflow/compiler/aot:benchmark",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//third_party/eigen3",
+ ] + if_android([
+ "//tensorflow/compiler/aot:benchmark_extra_android",
+ ]),
+ tags = tags,
+ )
+
def target_llvm_triple():
- """Returns the target LLVM triple to be used for compiling the target."""
- # TODO(toddw): Add target_triple for other targets. For details see:
- # http://llvm.org/docs/doxygen/html/Triple_8h_source.html
- return select({
- "//tensorflow:android_armeabi": "armv5-none-android",
- "//tensorflow:android_arm": "armv7-none-android",
- "//tensorflow:android_arm64": "aarch64-none-android",
- "//tensorflow:android_x86": "i686-none-android",
- "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
- "//tensorflow:darwin": "x86_64-none-darwin",
- "//conditions:default": "x86_64-pc-linux",
- })
+ """Returns the target LLVM triple to be used for compiling the target."""
+
+ # TODO(toddw): Add target_triple for other targets. For details see:
+ # http://llvm.org/docs/doxygen/html/Triple_8h_source.html
+ return select({
+ "//tensorflow:android_armeabi": "armv5-none-android",
+ "//tensorflow:android_arm": "armv7-none-android",
+ "//tensorflow:android_arm64": "aarch64-none-android",
+ "//tensorflow:android_x86": "i686-none-android",
+ "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
+ "//tensorflow:darwin": "x86_64-none-darwin",
+ "//conditions:default": "x86_64-pc-linux",
+ })
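
The reformatted macro above still generates the same targets as before; the cc_library target is what application code links against. A sketch of how a hypothetical tf_library(name = "foo", cpp_class = "mynamespace::MyComputation", ...) target is typically driven; the accessor names are assumptions based on what tfcompile-generated headers usually expose, not something defined in this diff:

    // #include "yourpkg/foo.h"  // the generated header for the hypothetical target
    int RunOnce() {
      mynamespace::MyComputation computation;  // the cpp_class named in the BUILD rule
      computation.arg0_data()[0] = 1.0f;       // typed positional accessors assumed
      if (!computation.Run()) return 1;        // Run() returns true on success
      const float result = computation.result0_data()[0];
      return result >= 0.0f ? 0 : 1;
    }
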
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index e34347b9d4..d3238c6a5e 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -306,6 +306,7 @@ cc_library(
srcs = [
"build_xla_launch_ops_pass.cc",
"deadness_analysis.cc",
+ "deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
"mark_for_compilation_pass.cc",
],
@@ -378,10 +379,38 @@ tf_cc_test(
)
tf_cc_test(
- name = "compilation_passes_test",
+ name = "deadness_analysis_test",
size = "small",
srcs = [
+ "deadness_analysis_internal.h",
"deadness_analysis_test.cc",
+ ],
+ deps = [
+ ":common",
+ ":compilation_passes",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:sendrecv_ops",
+ "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cc_test(
+ name = "compilation_passes_test",
+ size = "small",
+ srcs = [
"encapsulate_subgraphs_pass_test.cc",
"mark_for_compilation_pass_test.cc",
],
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index d81e5fe900..8aff87e5e6 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -151,7 +152,11 @@ class SymbolPredicate : public Predicate {
tensor_id_(std::move(tensor_id)),
must_be_true_(must_be_true) {}
- string ToString() const override { return tensor_id_.ToString(); }
+ string ToString() const override {
+ return must_be_true() ? strings::StrCat("*", tensor_id_.ToString())
+ : tensor_id_.ToString();
+ }
+
Kind kind() const override { return Kind::kSymbol; }
// If `must_be_true()` is true this SymbolPredicate represents the proposition
@@ -348,6 +353,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status Populate();
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
+ gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
@@ -563,4 +569,24 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
return Status::OK();
}
+gtl::FlatMap<TensorId, string, TensorId::Hasher>
+DeadnessAnalysisImpl::PredicateMapAsString() const {
+ gtl::FlatMap<TensorId, string, TensorId::Hasher> result;
+ std::vector<TensorId> tensor_ids;
+ for (const auto& kv_pair : predicate_map_) {
+ CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
+ }
+ return result;
+}
+
+namespace deadness_analysis_internal {
+Status ComputePredicates(const Graph& graph,
+ PredicateMapTy* out_predicate_map) {
+ DeadnessAnalysisImpl impl(&graph);
+ TF_RETURN_IF_ERROR(impl.Populate());
+ *out_predicate_map = impl.PredicateMapAsString();
+ return Status::OK();
+}
+} // namespace deadness_analysis_internal
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
new file mode 100644
index 0000000000..cdef405110
--- /dev/null
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
+#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
+
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+namespace deadness_analysis_internal {
+
+// Returns a map describing the predicate each Tensor was mapped to. For
+// testing purposes only.
+using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>;
+Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
+} // namespace deadness_analysis_internal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 584385cab7..6881095b51 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -439,5 +440,28 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) {
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
}
+TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
+ // Demonstrates why we need the must_be_true bit on SymbolPredicate.
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
+ 0, "receiver");
+ Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
+ ops::Switch sw(root.WithOpName("switch"), value, recv);
+ Output logical_and =
+ ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ deadness_analysis_internal::PredicateMapTy predicate_map;
+ TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(),
+ &predicate_map));
+
+ TensorId logical_and_output_0 = {logical_and.node()->name(),
+ Graph::kControlSlot};
+ EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 38eb6d830f..45d422943c 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -462,6 +462,7 @@ Status MarkForCompilationPass::Run(
VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
+ VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
const FunctionLibraryDefinition* fld = options.flib_def;
std::unique_ptr<DeadnessAnalysis> deadness;
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 08c357c879..0e2cdcf630 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -258,6 +258,7 @@ Status XlaCompilationCache::CompileImpl(
xla::LocalExecutable** executable,
const XlaCompiler::CompileOptions* compile_options,
bool compile_single_op) {
+ CHECK_NE(executable, nullptr);
VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
if (VLOG_IS_ON(2)) {
@@ -329,18 +330,15 @@ Status XlaCompilationCache::CompileImpl(
compile_options ? *compile_options : XlaCompiler::CompileOptions(),
function, args, &entry->compilation_result);
}
+ TF_RETURN_IF_ERROR(entry->compilation_status);
+ CHECK_EQ(entry->executable.get(), nullptr);
+ entry->compilation_status =
+ BuildExecutable(options, entry->compilation_result, &entry->executable);
}
+ TF_RETURN_IF_ERROR(entry->compilation_status);
*compilation_result = &entry->compilation_result;
- if (entry->compilation_status.ok() && executable) {
- if (entry->executable == nullptr) {
- entry->compilation_status = BuildExecutable(
- options, entry->compilation_result, &entry->executable);
- }
- *executable = entry->executable.get();
- }
-
- Status status = entry->compilation_status;
- return status;
+ *executable = entry->executable.get();
+ return Status::OK();
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 6ead15da13..422f36d43b 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -400,6 +400,21 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertEqual(75, y.numpy())
self.assertEqual(30, dy.numpy())
+ def testGradientTapeInDefun(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(5.0)
+
+ @function.defun
+ def f():
+ x = constant_op.constant(1.0)
+ with backprop.GradientTape() as tape:
+ y = v0 * x
+ dy = tape.gradient(y, v0)
+ return dy
+
+ dy = f()
+ self.assertEqual(1.0, dy.numpy())
+
def testSliceInDefun(self):
with self.test_scope():
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 16f293891d..c0ea242044 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -62,6 +62,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@@ -101,6 +102,9 @@ class OpTestBuilder {
OpTestBuilder& RandomInput(DataType type);
OpTestBuilder& RandomInput(DataType type, std::vector<int64> dims);
+ // Like RandomInput, but the generated values are unique.
+ OpTestBuilder& RandomUniqueInput(DataType type, std::vector<int64> dims);
+
// Sets an attribute.
template <class T>
OpTestBuilder& Attr(StringPiece attr_name, T&& value);
@@ -126,6 +130,7 @@ class OpTestBuilder {
DataType type = DT_INVALID;
bool has_dims = false;
+ bool needs_unique_values = false;
std::vector<int64> dims;
};
@@ -167,6 +172,18 @@ OpTestBuilder& OpTestBuilder::RandomInput(DataType type,
return *this;
}
+OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
+ std::vector<int64> dims) {
+ VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
+ InputDescription input;
+ input.type = type;
+ input.has_dims = true;
+ input.needs_unique_values = true;
+ input.dims = std::move(dims);
+ inputs_.push_back(input);
+ return *this;
+}
+
template <class T>
OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) {
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
@@ -289,7 +306,8 @@ class OpTest : public ::testing::Test {
// Returns a tensor filled with random but "reasonable" values from the middle
// of the type's range. If the shape is omitted, a random shape is used.
// TODO(phawkins): generalize this code to a caller-supplied distribution.
- Tensor RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape);
+ Tensor RandomTensor(DataType dtype, bool needs_unique_values,
+ gtl::ArraySlice<int64> shape);
Tensor RandomTensor(DataType dtype);
// Like RandomTensor, but uses values >= 0.
@@ -432,49 +450,90 @@ std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
return dims;
}
-Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
+Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
+ gtl::ArraySlice<int64> shape) {
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
+ gtl::FlatSet<float> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
- test::FillFn<float>(&tensor, [this, &distribution](int i) -> float {
- return distribution(generator());
+ test::FillFn<float>(&tensor, [&](int i) -> float {
+ float generated;
+ do {
+ generated = distribution(generator());
+ } while (needs_unique_values &&
+ !already_generated.insert(generated).second);
+ return generated;
});
break;
}
case DT_DOUBLE: {
+ gtl::FlatSet<double> already_generated;
std::uniform_real_distribution<double> distribution(-1.0, 1.0);
- test::FillFn<double>(&tensor, [this, &distribution](int i) -> double {
- return distribution(generator());
+ test::FillFn<double>(&tensor, [&](int i) -> double {
+ double generated;
+ do {
+ generated = distribution(generator());
+ } while (needs_unique_values &&
+ !already_generated.insert(generated).second);
+ return generated;
});
break;
}
case DT_COMPLEX64: {
+ gtl::FlatSet<std::pair<float, float>> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
- test::FillFn<complex64>(&tensor, [this, &distribution](int i) {
- return complex64(distribution(generator()), distribution(generator()));
+ test::FillFn<complex64>(&tensor, [&](int i) {
+ complex64 generated;
+ do {
+ generated =
+ complex64(distribution(generator()), distribution(generator()));
+ } while (
+ needs_unique_values &&
+ !already_generated
+ .insert(std::make_pair(generated.real(), generated.imag()))
+ .second);
+ return generated;
});
break;
}
case DT_INT32: {
+ gtl::FlatSet<int32> already_generated;
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
- test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
- return distribution(generator());
+ test::FillFn<int32>(&tensor, [&](int i) -> int32 {
+ int32 generated;
+ do {
+ generated = distribution(generator());
+ } while (needs_unique_values &&
+ !already_generated.insert(generated).second);
+ return generated;
});
break;
}
case DT_INT64: {
+ gtl::FlatSet<int64> already_generated;
std::uniform_int_distribution<int64> distribution(-(1LL << 40),
1LL << 40);
- test::FillFn<int64>(&tensor, [this, &distribution](int i) -> int64 {
- return distribution(generator());
+ test::FillFn<int64>(&tensor, [&](int i) -> int64 {
+ int64 generated;
+ do {
+ generated = distribution(generator());
+ } while (needs_unique_values &&
+ !already_generated.insert(generated).second);
+ return generated;
});
break;
}
case DT_BOOL: {
+ gtl::FlatSet<bool> already_generated;
std::bernoulli_distribution distribution;
- test::FillFn<bool>(&tensor, [this, &distribution](int i) -> bool {
- return distribution(generator());
+ test::FillFn<bool>(&tensor, [&](int i) -> bool {
+ bool generated;
+ do {
+ generated = distribution(generator());
+ } while (needs_unique_values &&
+ !already_generated.insert(generated).second);
+ return generated;
});
break;
}
@@ -485,7 +544,7 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
}
Tensor OpTest::RandomTensor(DataType dtype) {
- return RandomTensor(dtype, RandomDims());
+ return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims());
}
Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
@@ -761,7 +820,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
VLOG(1) << "Ignoring oversize dims.";
return kInvalid;
}
- input_tensors.push_back(RandomTensor(input.type, dims));
+ input_tensors.push_back(
+ RandomTensor(input.type, input.needs_unique_values, dims));
}
VLOG(1) << "Input: " << input_tensors.back().DebugString();
}
@@ -960,7 +1020,7 @@ TEST_F(OpTest, ArgMax) {
std::uniform_int_distribution<int32>(-num_dims, num_dims)(generator());
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("ArgMax")
- .RandomInput(DT_FLOAT, dims)
+ .RandomUniqueInput(DT_FLOAT, dims)
.Input(test::AsScalar<int32>(reduce_dim))
.Attr("T", DT_FLOAT)
.Attr("Tidx", DT_INT32)
@@ -976,7 +1036,7 @@ TEST_F(OpTest, ArgMin) {
std::uniform_int_distribution<int32>(-num_dims, num_dims)(generator());
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("ArgMin")
- .RandomInput(DT_FLOAT, dims)
+ .RandomUniqueInput(DT_FLOAT, dims)
.Input(test::AsScalar<int32>(reduce_dim))
.Attr("T", DT_FLOAT)
.Attr("Tidx", DT_INT32)
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 338943201b..61759fd276 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -92,6 +92,18 @@ cc_library(
)
cc_library(
+ name = "cpu_function_runtime",
+ srcs = ["cpu_function_runtime.cc"],
+ hdrs = ["cpu_function_runtime.h"],
+ deps = [
+ # Keep dependencies to a minimum here; this library is used in every AOT
+ # binary produced by tfcompile.
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:framework_lite",
+ ],
+)
+
+cc_library(
name = "xla_compiled_cpu_function",
srcs = ["xla_compiled_cpu_function.cc"],
hdrs = ["xla_compiled_cpu_function.h"],
@@ -99,12 +111,23 @@ cc_library(
deps = [
# Keep dependencies to a minimum here; this library is used in every AOT
# binary produced by tfcompile.
- "//tensorflow/compiler/aot:runtime",
+ ":cpu_function_runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
],
)
+tf_cc_test(
+ name = "cpu_function_runtime_test",
+ srcs = ["cpu_function_runtime_test.cc"],
+ deps = [
+ ":cpu_function_runtime",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_library(
name = "xla_jit_compiled_cpu_function",
srcs = ["xla_jit_compiled_cpu_function.cc"],
diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc
index 5e74079fc1..2ffad2af8c 100644
--- a/tensorflow/compiler/aot/runtime.cc
+++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,22 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/aot/runtime.h"
-
-#include <stdlib.h>
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
namespace tensorflow {
-namespace tfcompile {
-namespace runtime {
-
namespace {
-
// Inline memory allocation routines here, because depending on '//base' brings
// in libraries which use c++ streams, which adds considerable code size on
// android.
-inline void* aligned_malloc(size_t size, int minimum_alignment) {
+void* aligned_malloc(size_t size, int minimum_alignment) {
#if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN)
return memalign(minimum_alignment, size);
#elif defined(_WIN32)
@@ -47,7 +41,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) {
#endif
}
-inline void aligned_free(void* aligned_memory) {
+void aligned_free(void* aligned_memory) {
#if defined(_WIN32)
_aligned_free(aligned_memory);
#else
@@ -58,13 +52,13 @@ inline void aligned_free(void* aligned_memory) {
size_t align_to(size_t n, size_t align) {
return (((n - 1) / align) + 1) * align;
}
-
} // namespace
-size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
+namespace cpu_function_runtime {
+size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
- if (sizes[i] != -1) {
+ if (sizes[i] > 0) {
total += align_to(sizes[i], kAlign);
}
}
@@ -73,7 +67,7 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
bool annotate_initialized) {
- const size_t total = aligned_buffer_bytes(sizes, n);
+ const size_t total = AlignedBufferBytes(sizes, n);
void* contiguous = nullptr;
if (total > 0) {
contiguous = aligned_malloc(total, kAlign);
@@ -85,7 +79,9 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
}
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
- if (sizes[i] == -1) {
+ if (sizes[i] < 0) {
+ // bufs[i] is either a constant, an entry parameter or a thread local
+ // allocation.
bufs[i] = nullptr;
} else {
bufs[i] = reinterpret_cast<void*>(pos);
@@ -100,7 +96,5 @@ void FreeContiguous(void* contiguous) {
aligned_free(contiguous);
}
}
-
-} // namespace runtime
-} // namespace tfcompile
+} // namespace cpu_function_runtime
} // namespace tensorflow
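The arithmetic behind AlignedBufferBytes is easy to verify by hand: every positive size is rounded up to the next multiple of kAlign (64 bytes) and the rounded sizes are summed, while non-positive entries contribute nothing. The standalone sketch below re-derives the value checked in cpu_function_runtime_test.cc for {1, -1, 32, -1, 64, 2, 3}; it mirrors the library's align_to but is not the library code itself.

// Minimal re-derivation of the rounding performed by AlignedBufferBytes.
#include <cstddef>
#include <cstdint>
#include <iostream>

constexpr size_t kAlign = 64;  // matches cpu_function_runtime::kAlign

size_t AlignTo(size_t n, size_t align) { return ((n - 1) / align + 1) * align; }

size_t SumAligned(const intptr_t* sizes, size_t n) {
  size_t total = 0;
  for (size_t i = 0; i < n; ++i) {
    if (sizes[i] > 0) total += AlignTo(sizes[i], kAlign);
  }
  return total;
}

int main() {
  const intptr_t sizes[7] = {1, -1, 32, -1, 64, 2, 3};
  // 1->64, 32->64, 64->64, 2->64, 3->64; the two -1 entries are skipped.
  std::cout << SumAligned(sizes, 7) << "\n";  // prints 320
}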
diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h
index d1a669ceb1..c7b4559c65 100644
--- a/tensorflow/compiler/aot/runtime.h
+++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,25 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file contains utilities to make it easier to invoke functions generated
-// by tfcompile. Usage of these utilities is optional.
-
-#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_
-#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_
+#ifndef TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_
+#define TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-namespace tfcompile {
-namespace runtime {
+namespace cpu_function_runtime {
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
-static constexpr size_t kAlign = 64;
+constexpr size_t kAlign = 64;
-// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1
-// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign
-// byte boundaries.
-size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n);
+// AlignedBufferBytes returns the sum of each size in `sizes`, skipping
+// entries with non-positive sizes. There are `n` entries in `sizes`. Each
+// buffer is aligned to kAlign byte boundaries.
+size_t AlignedBufferBytes(const intptr_t* sizes, size_t n);
// MallocContiguousBuffers allocates buffers for use by the entry point
// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
@@ -41,8 +37,8 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n);
// temporary buffers.
//
// A single contiguous block of memory is allocated, and portions of it are
-// parceled out into `bufs`, which must have space for `n` entries. Returns the
-// head of the allocated contiguous block, which should be passed to
+// parceled out into `bufs`, which must have space for `n` entries. Returns
+// the head of the allocated contiguous block, which should be passed to
// FreeContiguous when the buffers are no longer in use.
void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
bool annotate_initialized);
@@ -50,9 +46,7 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
// FreeContiguous frees the contiguous block of memory allocated by
// MallocContiguousBuffers.
void FreeContiguous(void* contiguous);
-
-} // namespace runtime
-} // namespace tfcompile
+} // namespace cpu_function_runtime
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_
+#endif // TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_
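Call sites of the relocated runtime use the same allocate/free pairing as before, only through the new cpu_function_runtime namespace. A minimal usage sketch, assuming the TensorFlow headers are on the include path; the buffer sizes are made up for illustration:

// Hypothetical caller of the relocated cpu_function_runtime API.
#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"

#include <cstdint>
#include <cstring>

int main() {
  // Two real buffers and one negative entry that the runtime leaves null.
  const intptr_t sizes[3] = {8, -1, 100};
  void* bufs[3];
  void* block = tensorflow::cpu_function_runtime::MallocContiguousBuffers(
      sizes, 3, bufs, /*annotate_initialized=*/false);
  std::memset(bufs[0], 0, 8);    // bufs[0] and bufs[2] point into `block`
  std::memset(bufs[2], 0, 100);  // bufs[1] is nullptr
  tensorflow::cpu_function_runtime::FreeContiguous(block);
}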
diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc
index 06ec623eb2..f4f27a1562 100644
--- a/tensorflow/compiler/aot/runtime_test.cc
+++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc
@@ -13,39 +13,37 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/aot/runtime.h"
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
-namespace tfcompile {
-namespace runtime {
namespace {
-TEST(Runtime, AlignmentValue) {
+TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
// We've chosen 64 byte alignment for the tfcompile runtime to mimic the
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
// The tfcompile runtime also has a requirement that comes from the xla
// generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8
// So any value that we choose must abide by that constraint as well.
- EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment);
+ EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment);
}
-TEST(Runtime, AlignedBufferBytes) {
- EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0);
+TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) {
+ EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0);
static constexpr intptr_t sizesA[1] = {-1};
- EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0);
+ EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0);
static constexpr intptr_t sizesB[1] = {3};
- EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64);
+ EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64);
static constexpr intptr_t sizesC[1] = {32};
- EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64);
+ EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64);
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
- EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320);
+ EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320);
}
void* add_ptr(void* base, uintptr_t delta) {
@@ -56,48 +54,49 @@ void* add_ptr(void* base, uintptr_t delta) {
// expected nullptrs, and write to each byte of allocated memory. We rely on
// the leak checker to tell us if there's an inconsistency between malloc and
// free. We also check the contiguous property.
-TEST(Runtime, MallocFreeContiguousBuffers) {
+TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test empty sizes.
- void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false);
+ void* base =
+ cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false);
EXPECT_EQ(base, nullptr);
- FreeContiguous(base);
+ cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with 0 sum.
static constexpr intptr_t sizesA[1] = {-1};
void* bufA[1];
- base = MallocContiguousBuffers(sizesA, 1, bufA, false);
+ base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false);
EXPECT_EQ(base, nullptr);
EXPECT_EQ(bufA[0], nullptr);
- FreeContiguous(base);
+ cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with non-0 sum.
static constexpr intptr_t sizesB[1] = {3};
void* bufB[1];
- base = MallocContiguousBuffers(sizesB, 1, bufB, false);
+ base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufB[0], add_ptr(base, 0));
char* bufB0_bytes = static_cast<char*>(bufB[0]);
bufB0_bytes[0] = 'A';
bufB0_bytes[1] = 'B';
bufB0_bytes[2] = 'C';
- FreeContiguous(base);
+ cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with non-0 sum, and annotate_initialized.
static constexpr intptr_t sizesC[1] = {3};
void* bufC[1];
- base = MallocContiguousBuffers(sizesC, 1, bufC, true);
+ base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufC[0], add_ptr(base, 0));
char* bufC0_bytes = static_cast<char*>(bufC[0]);
bufC0_bytes[0] = 'A';
bufC0_bytes[1] = 'B';
bufC0_bytes[2] = 'C';
- FreeContiguous(base);
+ cpu_function_runtime::FreeContiguous(base);
// Test mixed sizes.
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
void* bufD[7];
- base = MallocContiguousBuffers(sizesD, 7, bufD, false);
+ base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufD[0], add_ptr(base, 0));
EXPECT_EQ(bufD[1], nullptr);
@@ -115,10 +114,8 @@ TEST(Runtime, MallocFreeContiguousBuffers) {
}
}
}
- FreeContiguous(base);
+ cpu_function_runtime::FreeContiguous(base);
}
} // namespace
-} // namespace runtime
-} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index 672e19bd93..334459138b 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include <cassert>
-#include "tensorflow/compiler/aot/runtime.h"
namespace tensorflow {
@@ -26,20 +26,29 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
result_index_(static_data.result_index),
args_(new void*[static_data.num_args]),
temps_(new void*[static_data.num_temps]),
+ arg_index_to_temp_index_(new int32[static_data.num_args]),
+ num_args_(static_data.num_args),
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
program_shape_(static_data.program_shape),
hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
// Allocate arg and temp buffers.
if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
- alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ alloc_args_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.arg_sizes, static_data.num_args, args_,
/*annotate_initialized=*/false);
}
- alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers(
static_data.temp_sizes, static_data.num_temps, temps_,
/*annotate_initialized=*/true);
+ for (int i = 0; i < static_data.num_temps; i++) {
+ if (static_data.temp_sizes[i] < -1) {
+ int32 param_number = -(static_data.temp_sizes[i] + 2);
+ arg_index_to_temp_index_[param_number] = i;
+ }
+ }
+
// If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function
@@ -50,11 +59,24 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
}
}
+bool XlaCompiledCpuFunction::Run() {
+ // Propagate pointers to the argument buffers into the temps array. Code
+ // generated by XLA discovers the incoming argument pointers from the temps
+ // array.
+ for (int32 i = 0; i < num_args_; i++) {
+ temps_[arg_index_to_temp_index_[i]] = args_[i];
+ }
+ raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
+ profile_counters_);
+ return true;
+}
+
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
+ cpu_function_runtime::FreeContiguous(alloc_args_);
+ cpu_function_runtime::FreeContiguous(alloc_temps_);
delete[] args_;
delete[] temps_;
+ delete[] arg_index_to_temp_index_;
delete[] profile_counters_;
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 48a8c083ca..27cfb354bf 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -60,9 +60,19 @@ class XlaCompiledCpuFunction {
// The raw function to call.
RawFunction raw_function;
- // Cardinality and sizes of arg and temp buffers.
+ // Cardinality and size of arg buffers.
const intptr_t* arg_sizes = nullptr;
size_t num_args = 0;
+
+ // Cardinality and size of temp buffers.
+ //
+ // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer.
+ //
+ // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The
+ // corresponding entry in the temp buffer array needs to be set to null.
+ //
+ // If temp_sizes[i] < -1 then the i'th temp is the entry parameter
+ // -(temp_sizes[i] + 2).
const intptr_t* temp_sizes = nullptr;
size_t num_temps = 0;
@@ -113,11 +123,7 @@ class XlaCompiledCpuFunction {
// Runs the computation, with inputs read from arg buffers, and outputs
// written to result buffers. Returns true on success and false on failure.
- bool Run() {
- raw_function_(temps_[result_index_], &run_options_,
- const_cast<const void**>(args_), temps_, profile_counters_);
- return true;
- }
+ bool Run();
// Returns the error message from the previous failed Run call.
//
@@ -224,6 +230,17 @@ class XlaCompiledCpuFunction {
void** args_ = nullptr;
void** temps_ = nullptr;
+ // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for
+ // XLA generated code to be able to find it.
+ //
+ // For now we need to keep around the args_ array because there is code that
+ // depends on args() returning a void**. However, in the future we may remove
+ // args_ in favor of using temps_ as the sole storage for the arguments.
+ int32* arg_index_to_temp_index_;
+
+ // The number of incoming arguments.
+ int32 num_args_;
+
// Backing memory for individual arg and temp buffers.
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;
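The temp_sizes encoding documented above folds three cases into one signed value: a size >= 0 is an ordinary temporary, -1 marks a constant or thread-local allocation, and any value below -1 encodes entry parameter number -(temp_sizes[i] + 2). A small, self-contained round-trip check of that encoding (the helper names are illustrative, not library functions):

// Round-trips the temp_sizes encoding documented above.
#include <cassert>
#include <cstdint>

// Encode an entry-parameter number into a temp_sizes slot.
intptr_t EncodeParameter(int32_t parameter_number) {
  return -static_cast<intptr_t>(parameter_number) - 2;
}

// Decode it back, as XlaCompiledCpuFunction's constructor does.
int32_t DecodeParameter(intptr_t temp_size) {
  assert(temp_size < -1);
  return static_cast<int32_t>(-(temp_size + 2));
}

int main() {
  for (int32_t p = 0; p < 5; ++p) {
    assert(DecodeParameter(EncodeParameter(p)) == p);
  }
  // Parameter 0 encodes to -2, parameter 1 to -3, and so on; -1 stays
  // reserved for constants and thread-local allocations.
  return 0;
}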
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 00ccfb1c78..114a9241bd 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -58,11 +58,15 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
std::vector<intptr_t> temp_sizes;
temp_sizes.reserve(allocations.size());
for (const xla::BufferAllocation& allocation : allocations) {
- // Callers don't allocate temporary buffers for parameters. Nor for
- // thread-local buffers, which are lowered to alloca.
- if (allocation.is_entry_computation_parameter() ||
- allocation.is_thread_local()) {
+ if (allocation.is_constant() || allocation.is_thread_local()) {
+ // Constants are lowered to globals. Thread locals are lowered to
+ // allocas.
temp_sizes.push_back(-1);
+ } else if (allocation.is_entry_computation_parameter()) {
+ // Entry computation parameters need some preprocessing in
+ // XlaCompiledCpuFunction::Run. See the comment on
+ // XlaCompiledCpuFunction::StaticData::temp_sizes.
+ temp_sizes.push_back(-allocation.parameter_number() - 2);
} else {
temp_sizes.push_back(allocation.size());
}
diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc
index 3a744148fb..6ef8168948 100644
--- a/tensorflow/compiler/xla/client/lib/prng.cc
+++ b/tensorflow/compiler/xla/client/lib/prng.cc
@@ -56,7 +56,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
- auto round = [builder](ThreeFry2x32State v, int rotation) {
+ auto round = [](ThreeFry2x32State v, int rotation) {
v[0] = v[0] + v[1];
v[1] = RotateLeftS32(v[1], rotation);
v[1] = v[0] ^ v[1];
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index b1a776b8b8..081fec7ad9 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -98,14 +98,13 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
<< "Computation should have progran shape.";
auto program_shape = computation.proto().program_shape();
- // For every (unbound) parameter that the computation wants, we manufacture
- // some arbitrary data so that we can invoke the computation.
- std::vector<std::unique_ptr<GlobalData>> fake_arguments;
- for (const Shape& parameter : program_shape.parameters()) {
- fake_arguments.push_back(MakeFakeDataOrDie(parameter, client));
- }
-
- return fake_arguments;
+ // Create and run a program which produces a tuple with one element per
+ // parameter, then return the tuple's constituent buffers.
+ std::vector<Shape> param_shapes(program_shape.parameters().begin(),
+ program_shape.parameters().end());
+ auto fake_input_tuple =
+ MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client);
+ return client->DeconstructTuple(*fake_input_tuple).ValueOrDie();
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index e7250e11d5..8a6c5fb9a7 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -101,11 +101,14 @@ Status LocalExecutable::ValidateExecutionOptions(
}
}
- // Verify that the device the executable was built for is equivalent to the
- // device it will run on.
- int run_device_ordinal = run_options.device_ordinal() == -1
- ? backend_->default_device_ordinal()
- : run_options.device_ordinal();
+ // Verify that the device the executable was built for is equivalent
+ // to the device it will run on.
+ int run_device_ordinal = run_options.device_ordinal();
+ if (run_device_ordinal == -1) {
+ run_device_ordinal = run_options.stream() != nullptr
+ ? run_options.stream()->parent()->device_ordinal()
+ : backend_->default_device_ordinal();
+ }
TF_ASSIGN_OR_RETURN(bool devices_equivalent,
backend_->devices_equivalent(
run_device_ordinal, build_options_.device_ordinal()));
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 53be5a79c2..1cb61f77fb 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1635,6 +1635,32 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
});
}
+XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates,
+ const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
+ TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape,
+ GetShape(scatter_indices));
+ TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates));
+ TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
+ update_computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferScatterShape(
+ input_shape, scatter_indices_shape, updates_shape,
+ to_apply_shape, dimension_numbers));
+
+ *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
+
+ AddCalledComputation(update_computation, &instr);
+ return AddInstruction(std::move(instr), HloOpcode::kScatter,
+ {input, scatter_indices, updates});
+ });
+}
+
XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
@@ -1681,7 +1707,7 @@ XlaOp XlaBuilder::Reduce(
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferReduceShape(
- operand_shape, init_shape, dimensions_to_reduce,
+ {&operand_shape, &init_shape}, dimensions_to_reduce,
called_program_shape));
for (int64 dim : dimensions_to_reduce) {
@@ -2803,6 +2829,13 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
window_bounds);
}
+XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates, const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers) {
+ return input.builder()->Scatter(input, scatter_indices, updates,
+ update_computation, dimension_numbers);
+}
+
void Send(const XlaOp& operand, const ChannelHandle& handle) {
return operand.builder()->Send(operand, handle);
}
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index ae331407d6..8726cc6f93 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -857,6 +857,11 @@ class XlaBuilder {
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ // Enqueues a Scatter node onto the computation.
+ XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates, const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers);
+
// Enqueues a Send node onto the computation for device-to-device
// communication, to send the given operand to a Recv instruction that shares
// the same channel handle.
@@ -1296,6 +1301,10 @@ class XlaBuilder {
friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates,
+ const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers);
friend void Send(const XlaOp& operand, const ChannelHandle& handle);
friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle);
@@ -1977,6 +1986,11 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+// Enqueues a Scatter node onto the computation.
+XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates, const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers);
+
// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
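The new Scatter wrapper mirrors Gather: it takes the operand, the scatter indices, the updates, a scalar update computation, and a ScatterDimensionNumbers proto describing how indices map onto operand dimensions. A rough usage sketch is below; it assumes the usual client-library headers are available and leaves the dimension numbers default-constructed, since their individual fields are not shown in this diff and must be filled in per the operation's semantics.

// Hedged sketch of calling the new Scatter wrapper. Builds an "add" update
// computation and applies it; `dnums` would need to be populated for a real
// computation.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp ScatterAdd(const xla::XlaOp& operand, const xla::XlaOp& indices,
                      const xla::XlaOp& updates,
                      const xla::ScatterDimensionNumbers& dnums) {
  // The update computation combines the existing element with the update.
  xla::XlaBuilder add_builder("scatter_add_update");
  auto lhs = xla::Parameter(&add_builder, 0,
                            xla::ShapeUtil::MakeShape(xla::F32, {}), "lhs");
  auto rhs = xla::Parameter(&add_builder, 1,
                            xla::ShapeUtil::MakeShape(xla::F32, {}), "rhs");
  xla::Add(lhs, rhs);
  xla::XlaComputation add = add_builder.Build().ValueOrDie();
  return xla::Scatter(operand, indices, updates, add, dnums);
}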
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 15eeb2ea13..b72d190d54 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -297,7 +297,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
shape.layout().padded_dimensions_size() == 0) {
return false;
}
- CHECK(IsDenseArray(shape));
+ CHECK(IsDenseArray(shape)) << shape.ShortDebugString();
CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size());
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) {
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 548fbe8a83..356f12ed78 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::strings::Printf;
using tensorflow::strings::StrCat;
namespace xla {
diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc
index fed0e58e66..69ef4f7a2f 100644
--- a/tensorflow/compiler/xla/metric_table_report.cc
+++ b/tensorflow/compiler/xla/metric_table_report.cc
@@ -134,8 +134,7 @@ void MetricTableReport::AppendHeader() {
void MetricTableReport::AppendCategoryTable() {
const std::vector<Category> categories = MakeCategories(&entries_);
- AppendLine("********** categories table **********");
- AppendLine("The left hand side numbers are ", metric_name_, ".");
+ AppendLine("********** categories table for ", metric_name_, " **********");
AppendLine();
double metric_sum = UnaccountedMetric();
@@ -185,8 +184,8 @@ void MetricTableReport::AppendCategoryTable() {
}
void MetricTableReport::AppendEntryTable() {
- AppendLine("********** ", entry_name_, " table **********");
- AppendLine("The left hand side numbers are ", metric_name_, ".");
+ AppendLine("********** ", entry_name_, " table for ", metric_name_,
+ " **********");
AppendLine();
double metric_sum = UnaccountedMetric();
diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD
index 8999cda5ef..d790c4db6c 100644
--- a/tensorflow/compiler/xla/python_api/BUILD
+++ b/tensorflow/compiler/xla/python_api/BUILD
@@ -10,6 +10,8 @@ py_library(
srcs = ["types.py"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:platform",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py
index e036c894f8..57dfce3971 100644
--- a/tensorflow/compiler/xla/python_api/types.py
+++ b/tensorflow/compiler/xla/python_api/types.py
@@ -23,6 +23,7 @@ import collections
import numpy as _np # Avoids becoming a part of public Tensorflow API.
from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.python.framework import dtypes
# Records corresponsence between a XLA primitive type and Python/Numpy types.
#
@@ -40,6 +41,12 @@ TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [
# Maps from XLA primitive types to TypeConversionRecord.
MAP_XLA_TYPE_TO_RECORD = {
+ xla_data_pb2.BF16:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.BF16,
+ numpy_dtype=dtypes.bfloat16.as_numpy_dtype,
+ literal_field_name='bf16s',
+ literal_field_type=float),
xla_data_pb2.F16:
TypeConversionRecord(
primitive_type=xla_data_pb2.F16,
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index ad14fe6f2c..862cbeeba6 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2006,7 +2006,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
// Builds a convolution from <options> and runs algebraic simplification on
// the computation. Returns a string description of the result of
// simplification.
- auto build_and_simplify = [&options, this]() -> string {
+ auto build_and_simplify = [&options]() -> string {
HloComputation::Builder b(TestName());
Window window;
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 95b4cb6d2e..51ebc4763b 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -109,11 +109,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
ResolveInternal(data));
for (const auto& shaped_buffer : replicated_buffers) {
std::vector<ShapeIndex> shape_indices;
- ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
- [this, &shape_indices](const Shape& /*subshape*/,
- const ShapeIndex& index) {
- shape_indices.push_back(index);
- });
+ ShapeUtil::ForEachSubshape(
+ shaped_buffer->on_device_shape(),
+ [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
+ shape_indices.push_back(index);
+ });
for (const ShapeIndex& index : shape_indices) {
TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
shaped_buffer->device_ordinal()));
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index 32f785a70a..a725351462 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -137,9 +137,9 @@ ENTRY entry {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}
- ASSERT_TRUE(instruction->has_sharding());
- TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
- EXPECT_EQ(device, 1);
+ auto device = instruction->sharding_unique_device();
+ ASSERT_TRUE(device);
+ EXPECT_EQ(*device, 1);
}
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index e4d2e73b99..118a11c8de 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -877,8 +877,8 @@ Status BufferAssigner::AssignBuffersForComputation(
// important reuse case where an elementwise instruction reuses one of its
// operand's buffer. This improves locality.
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
- [this, has_sequential_order, &liveness, &post_order_position,
- assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
+ [has_sequential_order, &liveness, &post_order_position, assignment](
+ const LogicalBuffer* a, const LogicalBuffer* b) {
// Primary sort is by decreasing buffer size.
const int64 a_size = assignment->buffer_size_(*a);
const int64 b_size = assignment->buffer_size_(*b);
@@ -1441,9 +1441,9 @@ void BufferAssigner::BuildColocatedBufferSets(
const HloInstruction* while_hlo = instruction;
ShapeUtil::ForEachSubshape(
while_hlo->shape(),
- [this, while_hlo, &points_to_analysis, &buffer_liveness,
- buffer_size, computation, colocated_buffer_sets](
- const Shape& /*subshape*/, const ShapeIndex& index) {
+ [this, while_hlo, &points_to_analysis, buffer_size,
+ colocated_buffer_sets](const Shape& /*subshape*/,
+ const ShapeIndex& index) {
std::vector<const LogicalBuffer*> colocated_set;
// Add while.init.
AddBufferToColocatedSet(while_hlo->operand(0), index,
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 6a7eb85e3b..128eea4828 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -156,9 +156,26 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
codegen_passes.run(module);
- // Construct ObjectFile from machine code buffer.
- return std::unique_ptr<llvm::MemoryBuffer>(
+ std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
+
+ if (VLOG_IS_ON(2)) {
+ llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
+ llvm::object::ObjectFile::createObjectFile(*memory_buffer);
+ if (obj_file) {
+ StatusOr<DisassemblerResult> disasm_result =
+ disassembler_->DisassembleObjectFile(*obj_file.get());
+ if (disasm_result.ok()) {
+ XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text);
+ } else {
+ LOG(WARNING) << "Could not disassemble object file!";
+ }
+ } else {
+ LOG(WARNING) << "Could convert memory buffer to object file!";
+ }
+ }
+
+ return memory_buffer;
}
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index b49ea89896..8cbe9a1b0d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -840,18 +840,29 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
- // Callers don't need to allocate temporary buffers for parameters.
- if (allocation.is_entry_computation_parameter() ||
- allocation.is_constant()) {
- buffer_sizes.push_back(-1);
- continue;
- }
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if (allocation.is_thread_local()) {
buffer_sizes.push_back(-1);
continue;
}
+
+ // Callers don't need to allocate anything for constant buffers. They are
+ // lowered to globals.
+ if (allocation.is_constant()) {
+ buffer_sizes.push_back(-1);
+ continue;
+ }
+
+ // Callers don't need to allocate anything for entry computation parameters,
+ // but they do need to stash a pointer to each parameter's buffer in the
+ // temp buffer table. See the comment on
+ // XlaCompiledCpuFunction::StaticData::temp_sizes.
+ if (allocation.is_entry_computation_parameter()) {
+ buffer_sizes.push_back(-allocation.parameter_number() - 2);
+ continue;
+ }
+
buffer_sizes.push_back(allocation.size());
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 81e17a5cd4..946f5124b8 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable(
// guarded by the mutex.
compute_function_ =
reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress()));
+ VLOG(1) << "compute_function_ at address "
+ << reinterpret_cast<void*>(compute_function_);
}
-Status CpuExecutable::AllocateBuffers(
+StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
+ std::vector<OwningDeviceMemory>>>
+CpuExecutable::CreateTempArray(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- std::vector<OwningDeviceMemory>* buffers) {
- CHECK_EQ(buffers->size(), assignment_->Allocations().size());
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ std::vector<se::DeviceMemoryBase> unowning_buffers(
+ assignment_->Allocations().size());
+ std::vector<OwningDeviceMemory> owning_buffers(
+ assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
@@ -84,6 +91,8 @@ Status CpuExecutable::AllocateBuffers(
VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) {
+ unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
+ allocation.param_shape_index());
VLOG(3) << "allocation #" << i << " is a parameter";
continue;
}
@@ -99,34 +108,34 @@ Status CpuExecutable::AllocateBuffers(
}
int64 buffer_size = allocation.size();
- if (!(*buffers)[i].is_null()) {
+ if (!owning_buffers[i].is_null()) {
VLOG(3) << "buffer #" << i
<< " is in the preallocated result ShapedBuffer";
} else {
- TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
- device_ordinal, buffer_size));
+ TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate(
+ device_ordinal, buffer_size));
+ unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase();
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
- << (*buffers)[i].opaque() << "]";
+ << owning_buffers[i].opaque() << "]";
}
// Since the output buffer and all the temporary buffers were written into
// by the JITed code, msan has no way of knowing their memory was
// initialized. Mark them initialized so that msan doesn't flag loads from
// these buffers.
- TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size);
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
VLOG(3) << "result index: " << result_slice.index();
- return Status::OK();
+ return {{std::move(unowning_buffers), std::move(owning_buffers)}};
}
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
@@ -136,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction(
//
// result: Points at the result.
// run_options: the ExecutableRunOptions object.
- // args_array: An array of pointers, each of which points to a parameter.
- // The size of this array is determined by the function's arity
- // (ProgramShape).
- // temps_array: An array of pointers, each of which points to a temporary
- // buffer the computation needs. The size of this array is
- // determined by buffer analysis.
+ // args_array: null
+ // temps_array: An array of pointers, containing pointers to temporary buffers
+ // required by the executable and pointers to entry computation
+ // parameters.
//
- std::vector<const void*> args_array;
- for (const ShapedBuffer* argument : arguments) {
- args_array.push_back(argument->root_buffer().opaque());
- }
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
@@ -169,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction(
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
VLOG(3) << tensorflow::strings::Printf(
- " func(void* result, void* params[%zu], void* temps[%zu], "
+ " func(void* result, void* params[null], void* temps[%zu], "
"uint64 profile_counters[%zu])",
- args_array.size(), buffer_pointers.size(), profile_counters_size);
+ buffer_pointers.size(), profile_counters_size);
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
};
- VLOG(3) << tensorflow::strings::Printf(
- " params = [%s]",
- tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str());
+ VLOG(3) << " params = nullptr";
VLOG(3) << tensorflow::strings::Printf(
" temps = [%s]",
tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
@@ -186,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction(
profile_counters);
}
- compute_function_(result_buffer, run_options, args_array.data(),
- buffer_pointers.data(), profile_counters);
+ compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(),
+ profile_counters);
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -254,21 +255,18 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
-
- TF_RETURN_IF_ERROR(AllocateBuffers(
- memory_allocator, stream->parent()->device_ordinal(), &buffers));
+ std::vector<OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
- unowning_buffers.reserve(buffers.size());
- for (auto& buffer : buffers) {
- unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
- }
- TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(),
- arguments, unowning_buffers,
- hlo_execution_profile));
+ TF_ASSIGN_OR_RETURN(
+ std::tie(unowning_buffers, owning_buffers),
+ CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
+ arguments));
+
+ TF_RETURN_IF_ERROR(ExecuteComputeFunction(
+ &run_options->run_options(), unowning_buffers, hlo_execution_profile));
- return CreateResultShapedBuffer(run_options, &buffers);
+ return CreateResultShapedBuffer(run_options, &owning_buffers);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@@ -284,17 +282,15 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
- TF_RETURN_IF_ERROR(AllocateBuffers(
- memory_allocator, stream->parent()->device_ordinal(), &buffers));
-
+ std::vector<OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
- unowning_buffers.reserve(buffers.size());
- for (auto& buffer : buffers) {
- unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
- }
+ TF_ASSIGN_OR_RETURN(
+ std::tie(unowning_buffers, owning_buffers),
+ CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
+ arguments));
+
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
- CreateResultShapedBuffer(run_options, &buffers));
+ CreateResultShapedBuffer(run_options, &owning_buffers));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
@@ -312,7 +308,6 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
struct AsyncRunTask {
CpuExecutable* executable;
ServiceExecutableRunOptions run_options;
- std::vector<const ShapedBuffer*> arguments;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
@@ -320,15 +315,14 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
- &run_options.run_options(), arguments, unowning_buffers,
+ &run_options.run_options(), unowning_buffers,
/*hlo_execution_profile=*/nullptr));
}
};
- host_stream->EnqueueTask(AsyncRunTask{
- this, *run_options,
- std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()),
- unowning_buffers,
- std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))});
+ host_stream->EnqueueTask(
+ AsyncRunTask{this, *run_options, std::move(unowning_buffers),
+ std::make_shared<std::vector<OwningDeviceMemory>>(
+ std::move(owning_buffers))});
return std::move(result);
}
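With this change the JIT-compiled entry point is always called with a null args array; temporaries, the live-out result buffer, and the entry-computation parameters all travel through the single temps array that CreateTempArray assembles. A schematic of the resulting call, with simplified stand-in types rather than the real ComputeFunctionType:

// Schematic of the post-change calling convention. The names and the function
// pointer type are simplified stand-ins, not the real declarations.
#include <cstdint>
#include <vector>

using ComputeFn = void (*)(void* result, const void* run_options,
                           const void** args /* now always nullptr */,
                           void** temps, int64_t* profile_counters);

void Invoke(ComputeFn fn, void* result_buffer, const void* run_options,
            std::vector<void*>& temps, int64_t* profile_counters) {
  // Parameter pointers have already been copied into their slots of `temps`
  // by CreateTempArray (JIT) or XlaCompiledCpuFunction::Run (AOT), so the
  // args array is passed as nullptr.
  fn(result_buffer, run_options, /*args=*/nullptr, temps.data(),
     profile_counters);
}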
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 8dd47bfb86..8af8a5dfec 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -85,20 +85,29 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
- // Allocate buffers required for execution and assign them to the elements of
- // "buffers". "buffers" should be sized to the number of buffers in buffer
- // assignment. Each vector element corresponds to a particular Index. If
- // a vector element already contains a non-null DeviceMemoryBase, then no
- // buffer is assigned for this element.
- Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
- int device_ordinal,
- std::vector<OwningDeviceMemory>* buffers);
+ // Creates an array suitable for passing as the "temps" argument to the JIT
+ // compiled function pointer.
+ //
+ // Returns (unowning_buffers, owning_buffers) where:
+ //
+ // - unowning_buffers.data() can be passed as the temps argument as-is and
+ // includes pointers to the scratch storage required by the computation,
+ // the live-out buffer into which the result will be written, and the
+ // entry computation parameters.
+ //
+ // - owning_buffers contains owning pointers to the buffers that were
+ // allocated by this routine. This routine allocates buffers for temporary
+ // storage and the live-out buffer into which the computation writes its
+ // result.
+ StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
+ std::vector<OwningDeviceMemory>>>
+ CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
Status ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 54c52bc08f..639064040f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -92,9 +92,10 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
} // namespace
-void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
- const void* shape,
- xla::int32 shape_length) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
+__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
+ const void* shape,
+ xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "AcquireInfeedBufferForDequeue: "
<< ShapeString(shape, shape_length);
@@ -111,9 +112,11 @@ void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
return buffer->data();
}
-void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
+ void* buffer_ptr,
+ const void* shape_ptr,
+ xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "ReleaseInfeedBufferAfterDeque: "
<< ShapeString(shape_ptr, shape_length);
@@ -125,8 +128,10 @@ void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
std::move(shape));
}
-void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
+__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
+ const void* shape_ptr,
+ xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "AcquireOutfeedBufferForPopulation: "
<< ShapeString(shape_ptr, shape_length);
@@ -143,9 +148,11 @@ void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
return buffer->data();
}
-void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length,
+ void* buffer_ptr,
+ const void* shape_ptr,
+ xla::int32 shape_length) {
if (VLOG_IS_ON(2)) {
LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: "
<< ShapeString(shape_ptr, shape_length);
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index cf955a8add..c13d36776f 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -117,9 +119,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
ElementwiseSourceIndex(index, *hlo, i)));
operands.push_back(operand_value);
}
- return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
- hlo->to_apply(), operands,
- llvm_ir::IrName(hlo));
+ return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
+ operands, llvm_ir::IrName(hlo));
};
}
return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a6d8551841..ca645d3f1d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -116,6 +116,19 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
computation->root_instruction()->outer_dimension_partitions().size();
}
+ if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
+ TF_ASSIGN_OR_RETURN(
+ computation_root_allocation_,
+ assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
+ }
+
+ for (const HloInstruction* param : computation->parameter_instructions()) {
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
+ assignment_.GetUniqueTopLevelSlice(param));
+ computation_parameter_allocations_[param_slice.allocation()->index()] =
+ param->parameter_number();
+ }
+
InitializeIrFunction(function_name);
// The rdtscp instruction is x86 specific. We will fallback to LLVM's generic
// readcyclecounter if it is unavailable.
@@ -132,6 +145,8 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
// IR insert point.
compute_function_.reset();
+ computation_root_allocation_ = BufferAllocation::Slice();
+ computation_parameter_allocations_.clear();
return ir_function;
}
@@ -484,23 +499,11 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap(
- HloMapInstruction* map, const llvm_ir::IrArray::Index& index) {
- llvm::Function* mapped_ir_function =
- FindOrDie(emitted_functions_, map->to_apply());
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : map->operands()) {
- const llvm_ir::IrArray& array = GetIrArrayFor(operand);
- parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_));
- }
- return EmitElementFunctionCall(mapped_ir_function, map->shape(),
- parameter_addresses, "map_function");
-}
-
-Status IrEmitter::HandleMap(HloInstruction* map) {
- return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) {
- return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index);
- });
+llvm::Value* IrEmitter::EmitElementalMap(
+ const HloMapInstruction& map_instr,
+ tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
+ tensorflow::StringPiece name) {
+ return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
@@ -508,9 +511,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
const llvm_ir::IrArray::Index& index) {
const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
- HloComputation* function = reduce_window->to_apply();
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// We fold inputs into the accumulator and initialize it to
// the initial value on the reduce_window.
@@ -563,11 +563,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
// We are not in the padding, so carry out the computation.
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
- llvm::Value* input_value_address =
- input_array.EmitArrayElementAddress(input_index, &b_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce_window->shape(),
- {accumulator_address, input_value_address}, "reducer_function");
+ llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
+ llvm::Value* result = EmitThreadLocalCall(
+ *reduce_window->to_apply(),
+ {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
b_.CreateStore(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -623,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
"Dilation for SelectAndScatter is not implemented on CPU. ");
}
- // The select and scatter computations should have been emitted previously.
- llvm::Function* select_function =
- FindOrDie(emitted_functions_, select_and_scatter->select());
- llvm::Function* scatter_function =
- FindOrDie(emitted_functions_, select_and_scatter->scatter());
-
// Pseudo code for select-and-scatter:
//
// initialized_flag is initially off for every window, and is turned on after
@@ -733,11 +726,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
- const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* result = EmitElementFunctionCall(
- select_function, output_shape, {selected_value_address, operand_address},
+ llvm::Value* operand_element = b_.CreateLoad(operand_address);
+ llvm::Value* result = EmitThreadLocalCall(
+ *select_and_scatter->select(),
+ {b_.CreateLoad(selected_value_address), operand_element},
"select_function");
// If the 'select' function returns false, update the selected value and the
@@ -764,14 +758,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
- llvm::Value* source_value_address =
- source_array.EmitArrayElementAddress(source_index, &b_);
+ llvm::Value* source_value =
+ source_array.EmitReadArrayElement(source_index, &b_);
llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
- llvm::Value* output_value_address =
- output_array.EmitArrayElementAddress(selected_index, &b_);
- llvm::Value* scatter_value = EmitElementFunctionCall(
- scatter_function, source->shape(),
- {output_value_address, source_value_address}, "scatter_function");
+ llvm::Value* output_value =
+ output_array.EmitReadArrayElement(selected_index, &b_);
+ llvm::Value* scatter_value =
+ EmitThreadLocalCall(*select_and_scatter->scatter(),
+ {output_value, source_value}, "scatter_function");
output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -1248,46 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
- auto param_number = parameter->parameter_number();
- auto param_shape = parameter->shape();
-
- // We have to access the parameter at offset param_number in the params
- // array. The code generated here is equivalent to this C code:
- //
- // i8* param_address_untyped = params[param_number];
- // Param* param_address_typed = (Param*)param_address_untyped;
- //
- // Where Param is the actual element type of the underlying buffer (for
- // example, float for an XLA F32 element type).
- llvm::Value* params = compute_function_->parameters_arg();
- llvm::Value* param_address_offset =
- llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset);
- param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped")));
- if (is_top_level_computation_ &&
- hlo_module_config_.debug_options()
- .xla_llvm_enable_invariant_load_metadata()) {
- // In the entry computation the parameter slots in the %params argument are
- // invariant through program execution. In computations that are called
- // from the entry computation (via kWhile, kCall and kConditional) the
- // parameter slots are *not* invariant since they're written to by their
- // callers.
- param_address_untyped->setMetadata(
- llvm::LLVMContext::MD_invariant_load,
- llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{}));
- }
-
- llvm::Value* param_address_typed = b_.CreateBitCast(
- param_address_untyped, IrShapeType(param_shape)->getPointerTo());
- emitted_value_[parameter] = param_address_typed;
-
- if (!ShapeUtil::IsOpaque(param_shape)) {
- AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
- AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
- }
-
- VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed);
- return Status::OK();
+ return EmitTargetAddressForOp(parameter);
}
// Returns true if the relative order of the unreduced dimensions stays the same
@@ -1751,9 +1706,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
gtl::ArraySlice<int64> dimensions(reduce->dimensions());
- HloComputation* function = reduce->to_apply();
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@@ -1793,10 +1745,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
CHECK(index.end() == it);
// Apply the reduction function to the loaded value.
- llvm::Value* input_address =
- arg_array.EmitArrayElementAddress(input_index, &b_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce->shape(), {accumulator_addr, input_address},
+ llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
+ llvm::Value* result = EmitThreadLocalCall(
+ *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
"reduce_function");
b_.CreateStore(result, accumulator_addr);
@@ -1842,6 +1793,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
return Unimplemented("Send-done is not implemented on CPU.");
}
+Status IrEmitter::HandleScatter(HloInstruction*) {
+ return Unimplemented("Scatter is not implemented on CPUs.");
+}
+
Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
@@ -2134,18 +2089,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
HloComputation* computation = call->to_apply();
llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : call->operands()) {
- parameter_addresses.push_back(GetEmittedValueFor(operand));
- }
-
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
- parameter_addresses, &b_, computation->name(),
+ {}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
@@ -2156,8 +2106,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
call_args, root->shape(), root->outer_dimension_partitions(), &b_,
call_ir_function, computation->name()));
} else {
- EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
- emitted_value_[call], computation->name());
+ EmitGlobalCall(*computation, computation->name());
}
return Status::OK();
@@ -2238,12 +2187,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
const HloInstruction* init = xla_while->operand(0);
emitted_value_[xla_while] = GetEmittedValueFor(init);
- // The called computation should have been emitted previously.
- llvm::Function* condition_ir_function =
- FindOrDie(emitted_functions_, condition);
- llvm::Function* body_ir_function =
- FindOrDie(emitted_functions_, xla_while->while_body());
-
// Generating:
// while (Condition(while_result)) {
// // CopyInsertion pass inserts copies which enable 'while_result' to
@@ -2260,12 +2203,10 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
- llvm::Value* while_result = GetEmittedValueFor(xla_while);
- llvm::Value* while_condition = EmitElementFunctionCall(
- condition_ir_function, condition->root_instruction()->shape(),
- {while_result}, IrName(xla_while, "cond"));
+ EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
llvm::Value* while_predicate = b_.CreateICmpNE(
- while_condition,
+ b_.CreateLoad(
+ GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2280,8 +2221,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
b_.SetInsertPoint(body_bb);
// Calls the body function.
- EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
- IrName(xla_while, "body"));
+ EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
+
// Finishes with a branch back to the header.
b_.CreateBr(header_bb);
@@ -2449,8 +2390,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto pred = conditional->operand(0);
- auto true_arg = conditional->operand(1);
- auto false_arg = conditional->operand(2);
TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
pred->shape().element_type() == PRED)
<< "Predicate on a Conditional must be bool; got: "
@@ -2472,13 +2411,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
<< " and "
<< ShapeUtil::HumanString(false_computation->root_instruction()->shape());
- llvm::Function* true_function =
- FindOrDie(emitted_functions_, true_computation);
- llvm::Function* false_function =
- FindOrDie(emitted_functions_, false_computation);
-
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
- llvm::Value* conditional_result = GetEmittedValueFor(conditional);
// Generating:
// if (pred)
@@ -2495,12 +2428,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
SetToFirstInsertPoint(if_data.true_block, &b_);
- EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)},
- conditional_result, IrName(conditional, "_true"));
+ EmitGlobalCall(*conditional->true_computation(),
+ IrName(conditional, "_true"));
SetToFirstInsertPoint(if_data.false_block, &b_);
- EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)},
- conditional_result, IrName(conditional, "_false"));
+ EmitGlobalCall(*conditional->false_computation(),
+ IrName(conditional, "_false"));
SetToFirstInsertPoint(if_data.after_block, &b_);
return Status::OK();
@@ -2701,44 +2634,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
+llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
- llvm::Type* element_type = IrShapeType(target_shape);
- // The alignment and number of bytes within the temporary buffer is determined
- // by the maximal shape as determined by buffer assignment.
- const BufferAllocation& allocation = assignment_.GetAllocation(slice.index());
- if (allocation.is_thread_local()) {
+ const BufferAllocation& allocation = *slice.allocation();
+ llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
+ if (slice == computation_root_allocation_) {
+ llvm::Argument* retval = compute_function_->result_arg();
+ llvm::AttrBuilder attr_builder;
+ attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
+ attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
+ retval->addAttrs(attr_builder);
+ return retval;
+ }
+
+ auto param_it =
+ computation_parameter_allocations_.find(slice.allocation()->index());
+ if (param_it != computation_parameter_allocations_.end()) {
+ int64 param_number = param_it->second;
+ // We have to access the parameter at offset param_number in the params
+ // array. The code generated here is equivalent to this C code:
+ //
+ // i8* param_address_untyped = params[param_number];
+ // Param* param_address_typed = (Param*)param_address_untyped;
+ //
+ // Where Param is the actual element type of the underlying buffer (for
+ // example, float for an XLA F32 element type).
+ llvm::Value* params = compute_function_->parameters_arg();
+ llvm::Value* param_address_offset =
+ llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
+ llvm::LoadInst* param_address_untyped =
+ b_.CreateLoad(param_address_offset);
+
+ if (!ShapeUtil::IsOpaque(target_shape)) {
+ AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
+ AttachDereferenceableMetadataForLoad(param_address_untyped,
+ target_shape);
+ }
+ return param_address_untyped;
+ }
+
// Thread-local allocations should only be assigned a single buffer.
const auto& assigned_buffers = allocation.assigned_buffers();
CHECK_EQ(1, assigned_buffers.size());
const Shape& shape = assigned_buffers.begin()->first->shape();
- llvm::AllocaInst*& tempbuf_address =
- thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}];
- if (tempbuf_address == nullptr) {
- tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ std::pair<llvm::Function*, BufferAllocation::Slice> key = {
+ compute_function_->function(), slice};
+ auto buf_it = thread_local_buffers_.find(key);
+ if (buf_it == thread_local_buffers_.end()) {
+ llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
IrShapeType(shape),
tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
MinimumAlignmentForShape(target_shape));
+ auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
+ CHECK(it_inserted_pair.second);
+ buf_it = it_inserted_pair.first;
}
- return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo());
- }
-
- if (allocation.is_constant()) {
- return FindOrDie(constant_buffer_to_global_, allocation.index());
- }
+ return buf_it->second;
+ }();
+ return b_.CreateBitCast(tempbuf_address,
+ IrShapeType(target_shape)->getPointerTo());
+}
+llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape) {
+ const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
- if (is_top_level_computation_ &&
- hlo_module_config_.debug_options()
+ if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
- // In the entry computation the parameter slots in the %params argument are
- // invariant through program execution. In computations that are called
- // from the entry computation (via kWhile, kCall and kConditional) the
- // parameter slots are *not* invariant since they're written to by their
- // callers.
tempbuf_address_base->setMetadata(
llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
@@ -2753,85 +2718,25 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
return b_.CreateBitCast(tempbuf_address_untyped,
- element_type->getPointerTo());
+ IrShapeType(target_shape)->getPointerTo());
}
-// Emits a function call returning a single array element. Allocates space
-// for a single element_type value, and loads it after call.
-llvm::Value* IrEmitter::EmitElementFunctionCall(
- llvm::Function* function, const Shape& return_shape,
- gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name) {
- llvm::Value* return_value_buffer = EmitArrayFunctionCall(
- function, return_shape, 1, parameter_addresses, name);
- return b_.CreateLoad(
- return_value_buffer,
- AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
-}
-
-// Emits a core function call based on the following pseudo-code.
-//
-// char** parameter_addresses_buffer =
-// allocate buffer with a pointer for each parameter to the function
-// for each parameter index, i.e. for i = 0, ..., #parameters:
-// parameter_addresses_buffer[i] = parameter_addresses[i]
-// call function(return_value_buffer,
-// parameter_addresses_buffer,
-// temps)
-// return return_value_buffer -- address of the return value.
-void IrEmitter::EmitArrayFunctionCallInto(
- llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
- b_.CreateCall(function,
- GetArrayFunctionCallArguments(
- parameter_addresses, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
-}
-
-llvm::Value* IrEmitter::EmitArrayFunctionCall(
- llvm::Function* function, const Shape& return_shape, int64 element_count,
- gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name) {
- llvm::Value* elements =
- llvm::ConstantInt::get(b_.getInt64Ty(), element_count);
- PrimitiveType return_type = return_shape.element_type();
- llvm::Value* return_value_buffer =
- llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
- tensorflow::strings::StrCat(name, "_return_value_address"), &b_,
- MinimumAlignmentForPrimitiveType(return_type));
- EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
- name);
- return return_value_buffer;
+llvm::Value* IrEmitter::EmitTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape) {
+ if (slice.allocation()->is_thread_local()) {
+ return EmitThreadLocalTempBufferPointer(slice, target_shape);
+ } else if (slice.allocation()->is_constant()) {
+ return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
+ } else {
+ return EmitGlobalTempBufferPointer(slice, target_shape);
+ }
}
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
- llvm::Value* addr;
const Shape& target_shape = op->shape();
- if (op == op->parent()->root_instruction()) {
- // For the root node, we write directly to the output buffer of the
- // function.
- llvm::Argument* retval = compute_function_->result_arg();
- if ((ShapeUtil::IsArray(target_shape) &&
- !ShapeUtil::IsZeroElementArray(target_shape)) ||
- (ShapeUtil::IsTuple(target_shape) &&
- !ShapeUtil::IsEmptyTuple(target_shape))) {
- llvm::AttrBuilder attr_builder;
- attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
- attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
- retval->addAttrs(attr_builder);
- }
- addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo());
- } else {
- // For other nodes, we need the temporary buffer allocated for this node to
- // write the result into.
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- assignment_.GetUniqueTopLevelSlice(op));
- addr = EmitTempBufferPointer(slice, target_shape);
- }
+ TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+ assignment_.GetUniqueTopLevelSlice(op));
+ llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2936,20 +2841,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
-StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
- PrimitiveType return_type, HloComputation* computation,
- const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
- llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
- std::vector<llvm::Value*> argument_addrs;
- for (auto argument : arguments) {
- llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- argument->getType(), "arg_addr", &b_);
- b_.CreateStore(argument, argument_addr);
- argument_addrs.push_back(argument_addr);
+llvm::Value* IrEmitter::EmitThreadLocalCall(
+ const HloComputation& callee,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
+ tensorflow::StringPiece name) {
+ const Shape& return_shape = callee.root_instruction()->shape();
+
+ // Lifting this restriction to allow "small" arrays should be easy. Allowing
+ // larger arrays is difficult because we allocate the buffer for this return
+ // value on the stack.
+ CHECK(ShapeUtil::IsScalar(return_shape));
+
+ PrimitiveType return_type = return_shape.element_type();
+
+ std::vector<llvm::Value*> parameter_addrs;
+ for (llvm::Value* parameter : parameters) {
+ CHECK(!parameter->getType()->isPointerTy());
+ llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ parameter->getType(), "arg_addr", &b_);
+ b_.CreateStore(parameter, parameter_addr);
+ parameter_addrs.push_back(parameter_addr);
+ }
+
+ llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(return_type, module_),
+ tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
+ MinimumAlignmentForPrimitiveType(return_type));
+
+ b_.CreateCall(
+ FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ parameter_addrs, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+
+ return b_.CreateLoad(return_value_buffer);
+}
+
+void IrEmitter::EmitGlobalCall(const HloComputation& callee,
+ tensorflow::StringPiece name) {
+ b_.CreateCall(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ /*parameter_addresses=*/{}, &b_, name,
+ /*return_value_buffer=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+}
+
+llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
+ const HloComputation& callee) {
+ const HloInstruction* root_inst = callee.root_instruction();
+ if (root_inst->opcode() == HloOpcode::kOutfeed) {
+ return llvm::Constant::getNullValue(b_.getInt8PtrTy());
}
- return EmitElementFunctionCall(llvm_function,
- ShapeUtil::MakeShape(return_type, {}),
- argument_addrs, name);
+
+ const BufferAllocation::Slice root_buffer =
+ assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
+ return EmitTempBufferPointer(root_buffer, root_inst->shape());
}
+
} // namespace cpu
} // namespace xla
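The refactoring above replaces the per-element and array call emitters with two call styles: EmitThreadLocalCall spills scalar arguments into entry-block allocas, passes a null temp-buffer array, and reads the scalar result back from a stack slot, while EmitGlobalCall passes null for both the return-value and parameter arguments because buffer assignment has already fixed where the callee reads and writes. The following is a rough C++ sketch of the two invocation shapes from the caller's side, using a hypothetical stand-in for the generated entry point rather than the emitted LLVM IR (the dynamic-loop-bounds argument is omitted for brevity):

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the generated XLA CPU computation entry point;
// the real signature also carries a dynamic-loop-bounds argument.
using ComputeFn = void (*)(void* retval, const void* run_options,
                           const void** params, void** temps,
                           int64_t* prof_counters);

// Thread-local style: scalar arguments live in stack slots, the scalar result
// comes back through a stack slot, and no temp-buffer array is passed.
float CallThreadLocal(ComputeFn fn, float a, float b) {
  float retval = 0.0f;
  const void* params[] = {&a, &b};
  fn(&retval, /*run_options=*/nullptr, params, /*temps=*/nullptr,
     /*prof_counters=*/nullptr);
  return retval;
}

// Global style: neither a return buffer nor parameters are passed explicitly;
// the callee works entirely out of the temp-buffer array.
void CallGlobal(ComputeFn fn, std::vector<void*>& temps) {
  fn(/*retval=*/nullptr, /*run_options=*/nullptr, /*params=*/nullptr,
     temps.data(), /*prof_counters=*/nullptr);
}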
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 03bbb2afb5..c9a1dab62d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -100,14 +100,15 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
- // Emits a call to `computation` with scalar arguments `arguments`.
- StatusOr<llvm::Value*> EmitScalarCall(
- PrimitiveType return_type, HloComputation* computation,
- const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
-
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
+ // Emit code to map one element according to `map_instr`.
+ llvm::Value* EmitElementalMap(
+ const HloMapInstruction& map_instr,
+ tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
+ tensorflow::StringPiece name);
+
protected:
//
// The following methods implement the DfsHloVisitor interface.
@@ -143,13 +144,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override;
- Status HandleMap(HloInstruction* map) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
+ Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
Status HandleIota(HloInstruction* iota) override;
Status HandleRng(HloInstruction* rng) override;
@@ -218,9 +219,18 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// computation function being emitted by this emitter.
llvm::Value* GetTempBuffersArgument();
- // Emits code that computes the address of the given temporary buffer to the
- // function. target_shape is the shape of this temporary buffer.
- // The returned Value's type is a pointer to element_type.
+ // Helper for EmitTempBufferPointer.
+ llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
+
+ // Helper for EmitTempBufferPointer.
+ llvm::Value* EmitThreadLocalTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape);
+
+ // Emits code that computes the address of the given buffer allocation slice.
+ //
+ // TODO(sanjoy): This should be renamed to reflect that it no longer provides
+ // access to just temporaries.
llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
const Shape& target_shape);
@@ -232,44 +242,27 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::StringPiece
function_name_suffix); // Used for LLVM IR register names.
- // Methods that emit a function call.
- // Parameters:
- // function - The LLVM function to call.
- // return_shape - The return shape of the HLO computation that was used to
- // make the function. Not the same as the return type of the function
- // in LLVM, since we use output parameters for the return type.
- // element_count - number of elements to return (array form only).
- // parameter_addresses - pointers to be passed to the function as
- // parameters.
- // name - used for LLVM IR register names.
-
- // Emits a function call, returning a scalar, often an element of a larger
- // array. Returns a Value for the scalar element returned by the function.
- llvm::Value* EmitElementFunctionCall(
- llvm::Function* function, const Shape& return_shape,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ // Emits a call to a thread local function (e.g. to the computation nested
+ // within a reduce or a map). Thread local callees (by definition) only write
+ // to and read from thread local allocations.
+ //
+ // `parameters` holds the *scalar values* that need to be passed to the
+ // callee. The return value is the scalar returned by the callee.
+ llvm::Value* EmitThreadLocalCall(
+ const HloComputation& callee,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
tensorflow::StringPiece name);
- // Array function call emitter. Stores the function's result into a supplied
- // buffer.
- // Parameters:
- // function - The LLVM function to call.
- // parameter_addresses - pointers to be passed to the function as
- // parameters.
- // return_value - pointer to a buffer where the call result is stored.
-
- void EmitArrayFunctionCallInto(
- llvm::Function* function,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* return_value_buffer, tensorflow::StringPiece name);
-
- // Array function call emitter. Returns a Value for the function's return
- // value buffer address. The return value buffer is alloca'ed by this
- // function.
- llvm::Value* EmitArrayFunctionCall(
- llvm::Function* function, const Shape& return_shape, int64 element_count,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name);
+ // Emits a call to a "global" function (e.g. to the computation nested within
+  // a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
+ // the parameters and return values for these computations so there is no need
+ // to explicitly pass parameters or return results.
+ void EmitGlobalCall(const HloComputation& callee,
+ tensorflow::StringPiece name);
+
+ // Returns the buffer to which a global call to `callee` would have written
+ // its result.
+ llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
@@ -408,11 +401,10 @@ class IrEmitter : public DfsHloVisitorWithDefault {
NameUniquer name_uniquer_;
// Map containing all previously emitted computations.
- std::map<HloComputation*, llvm::Function*> emitted_functions_;
+ std::map<const HloComputation*, llvm::Function*> emitted_functions_;
// Map containing all previously emitted thread-local temporary buffers.
- std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
- llvm::AllocaInst*>
+ std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
thread_local_buffers_;
// The following fields track the IR emission state. According to LLVM memory
@@ -422,6 +414,16 @@ class IrEmitter : public DfsHloVisitorWithDefault {
std::unique_ptr<IrFunction> compute_function_;
llvm::IRBuilder<> b_;
+ // The buffer allocation slice for the root of the computation being compiled.
+ // Only relevant for thread local computations.
+ BufferAllocation::Slice computation_root_allocation_;
+
+ // Maps the buffer allocation slices for the parameters to the computation
+ // being compiled to their parameter numbers. Only relevant for thread local
+ // computations.
+ tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
+ computation_parameter_allocations_;
+
// Maps HLO instructions to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, int64>
instruction_to_profile_idx_;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 6aff838462..2db4d000f5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -80,9 +80,16 @@ void IrFunction::Initialize(const string& function_name,
// void function(i8* retval, i8* run_options, i8** params, i8** temps,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
- // retval: points to the returned value.
- // params: address of an array with pointers to parameters.
- // temps: address of an array with pointers to temporary buffers.
+ // For thread local functions:
+ // retval: points to the returned value.
+ // params: address of an array with pointers to parameters.
+ // temps: is null
+ //
+ // For global functions:
+ // retval: is null
+ // params: is null
+ // temps: address of an array with pointers to temporary buffers and entry
+ // computation parameters.
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
@@ -196,18 +203,25 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
llvm::IRBuilder<>* b, tensorflow::StringPiece name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
- llvm::Value* parameter_addresses_buffer =
- llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
- for (size_t i = 0; i < parameter_addresses.size(); ++i) {
- llvm::Value* parameter_as_i8ptr =
- b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(
- name, "_parameter_", i, "_address_as_i8ptr")));
- llvm::Value* slot_in_param_addresses =
- b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
- b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+ llvm::Value* parameter_addresses_buffer;
+
+ if (parameter_addresses.empty()) {
+ parameter_addresses_buffer =
+ llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
+ } else {
+ parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
+ tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
+
+ for (size_t i = 0; i < parameter_addresses.size(); ++i) {
+ llvm::Value* parameter_as_i8ptr =
+ b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
+ AsStringRef(tensorflow::strings::StrCat(
+ name, "_parameter_", i, "_address_as_i8ptr")));
+ llvm::Value* slot_in_param_addresses =
+ b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
+ b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+ }
}
const auto to_int8_ptr = [=](llvm::Value* ptr) {
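Since global-style calls now pass no parameters, GetArrayFunctionCallArguments emits a null i8** for an empty parameter list instead of allocating an empty address array. A minimal sketch of the same decision in plain C++ (helper name hypothetical, standing in for the emitted alloca-and-store sequence):

#include <vector>

// Hypothetical sketch: return nullptr for an empty parameter list, otherwise
// hand back an array filled with the parameter addresses.
void** BuildParameterAddressArray(
    const std::vector<void*>& parameter_addresses,
    std::vector<void*>& storage) {
  if (parameter_addresses.empty()) {
    return nullptr;  // global-style call: the callee never dereferences params
  }
  storage = parameter_addresses;  // stands in for the stack-allocated array
  return storage.data();
}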
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
index d03da46575..a5f34908d7 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -58,13 +59,14 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
// [partition1_dim2_start]
// [partition1_dim2_limit]
//
-void __xla_cpu_runtime_ParallelForkJoin(
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
void* result_ptr, const void* run_options_ptr, const void** params,
void** temps, uint64* prof_counters, int32 num_partitions,
int64* partitions, int32 num_partitioned_dims, void* function_ptr) {
VLOG(2) << "ParallelForkJoin ENTRY"
<< " num_partitions: " << num_partitions
<< " num_partitioned_dims: " << num_partitioned_dims;
+ CHECK_EQ(params, nullptr);
CHECK_GT(num_partitions, 1);
CHECK_GT(num_partitioned_dims, 0);
const xla::ExecutableRunOptions* run_options =
@@ -79,9 +81,9 @@ void __xla_cpu_runtime_ParallelForkJoin(
for (int32 i = 1; i < num_partitions; ++i) {
const int64 offset = i * stride;
run_options->intra_op_thread_pool()->enqueueNoNotification(
- [i, function, result_ptr, run_options_ptr, params, temps, prof_counters,
+ [i, function, result_ptr, run_options_ptr, temps, prof_counters,
partitions, offset, &bc]() {
- function(result_ptr, run_options_ptr, params, temps,
+ function(result_ptr, run_options_ptr, nullptr, temps,
&partitions[offset], prof_counters);
bc.DecrementCount();
VLOG(3) << "ParallelForkJoin partition " << i << " done.";
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
index 39b13183ff..a71a85913c 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@@ -77,27 +78,24 @@ void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
} // namespace
-void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr,
- Eigen::half* out, Eigen::half* lhs,
- Eigen::half* rhs, int64 m, int64 n,
- int64 k, int32 transpose_lhs,
- int32 transpose_rhs) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16(
+ const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
+ Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
MatMulImpl<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
transpose_lhs, transpose_rhs);
}
-void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out,
- float* lhs, float* rhs, int64 m, int64 n,
- int64 k, int32 transpose_lhs,
- int32 transpose_rhs) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32(
+ const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
MatMulImpl<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
-void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out,
- double* lhs, double* rhs, int64 m,
- int64 n, int64 k, int32 transpose_lhs,
- int32 transpose_rhs) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
+ const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
MatMulImpl<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
transpose_rhs);
}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
index f8c8dd5e93..997fdd2ab3 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
@@ -23,6 +23,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
+#include "tensorflow/core/platform/dynamic_annotations.h"
using tensorflow::int32;
using tensorflow::int64;
@@ -74,10 +75,9 @@ void MatMulF64(const void* run_options_ptr, double* out, double* lhs,
} // namespace
-void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out,
- float* lhs, float* rhs, int64 m, int64 n,
- int64 k, int32 transpose_lhs,
- int32 transpose_rhs) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF32(
+ const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
const xla::ExecutableRunOptions* run_options =
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
// BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
@@ -88,11 +88,11 @@ void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out,
// Set thread number back to the previous number.
mkl_set_num_threads_local(prev_num_threads);
}
+
// BLAS GEMM API for 64-bit Matrix Multiplication
-void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out,
- double* lhs, double* rhs, int64 m, int64 n,
- int64 k, int32 transpose_lhs,
- int32 transpose_rhs) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF64(
+ const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
const xla::ExecutableRunOptions* run_options =
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
// BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
@@ -103,22 +103,26 @@ void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out,
// Set thread number back to the previous number.
mkl_set_num_threads_local(prev_num_threads);
}
-void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
- float* out, float* lhs,
- float* rhs, int64 m, int64 n,
- int64 k, int32 transpose_lhs,
- int32 transpose_rhs) {
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
+ float* out, float* lhs, float* rhs,
+ int64 m, int64 n, int64 k,
+ int32 transpose_lhs,
+ int32 transpose_rhs) {
   // Set the thread number to 1 for single-threaded execution.
int prev_num_threads = mkl_set_num_threads_local(1);
MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
// Set thread number back to the previous number.
mkl_set_num_threads_local(prev_num_threads);
}
-void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
- double* out, double* lhs,
- double* rhs, int64 m, int64 n,
- int64 k, int32 transpose_lhs,
- int32 transpose_rhs) {
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
+ double* out, double* lhs,
+ double* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
   // Set the thread number to 1 for single-threaded execution.
int prev_num_threads = mkl_set_num_threads_local(1);
MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
index 17303e2f0d..16692e7f2e 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@@ -71,7 +72,8 @@ void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs,
} // namespace
-void __xla_cpu_runtime_EigenSingleThreadedMatMulF16(
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedMatMulF16(
const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
@@ -79,16 +81,22 @@ void __xla_cpu_runtime_EigenSingleThreadedMatMulF16(
transpose_lhs, transpose_rhs);
}
-void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
- const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
- int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr,
+ float* out, float* lhs,
+ float* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
SingleThreadedMatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k,
transpose_lhs, transpose_rhs);
}
-void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
- const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
- int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr,
+ double* out, double* lhs,
+ double* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
SingleThreadedMatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k,
transpose_lhs, transpose_rhs);
}
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c433bddc84..c35569c661 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -220,7 +220,7 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// The body adds the reduced value of the Infeed data (first tuple element)
// to the previous accumulator, and returns the accumulator and the continue
// flag (second tuple element) as a tuple.
- const auto build_body = [this, &result_shape](const Shape& infeed_shape) {
+ const auto build_body = [&result_shape](const Shape& infeed_shape) {
XlaComputation body;
XlaBuilder builder("body");
auto prev = Parameter(&builder, 0, result_shape, "prev");
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 097fa23027..9f86749125 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -233,6 +233,7 @@ class DfsHloVisitorBase {
virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
virtual Status HandleGather(HloInstructionPtr hlo) = 0;
+ virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
virtual Status HandlePad(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index f4316e0fb7..ae8a066d62 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -194,6 +194,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleGather(HloInstructionPtr gather) override {
return DefaultAction(gather);
}
+ Status HandleScatter(HloInstructionPtr scatter) override {
+ return DefaultAction(scatter);
+ }
Status HandleAfterAll(HloInstructionPtr token) override {
return DefaultAction(token);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index f883eb828c..574ae0c903 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2134,7 +2134,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
};
default:
- return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+ return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
HloOpcodeString(hlo->opcode()).c_str());
};
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index e0aae3866b..4947dd278e 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -636,7 +636,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
- "//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
@@ -749,6 +748,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
],
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index b3a3c5dcb4..2fd2206324 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -43,6 +43,8 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
HloExecutionProfiler* profiler) {
+ VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for "
+ << (hlo_instruction() ? hlo_instruction()->ToString() : "<null>");
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (int64 i = 0; i < loop_limit_; ++i) {
profiler->StartHloComputation();
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index dbcbabdc52..74282c568c 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -32,16 +32,19 @@ namespace {
// dimensions.
struct MatrixDescriptor {
MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
- int64 matrix_num_rows, int64 matrix_num_cols)
+ int64 matrix_num_rows, int64 matrix_num_cols,
+ int64 matrix_batch_size)
: data(matrix_data),
transpose(needs_transpose),
num_rows(matrix_num_rows),
- num_cols(matrix_num_cols) {}
+ num_cols(matrix_num_cols),
+ batch_size(matrix_batch_size) {}
se::DeviceMemoryBase data;
bool transpose; // Whether this matrix needs to be transposed.
int64 num_rows;
int64 num_cols;
+ int64 batch_size;
};
// Performs a gemm call without an explicit algorithm on lhs_matrix and
@@ -51,6 +54,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha, se::Stream* stream) {
DCHECK(!output_matrix.transpose);
+ const int64 batch_size = lhs_matrix.batch_size;
+ CHECK_EQ(batch_size, rhs_matrix.batch_size);
+ CHECK_EQ(batch_size, output_matrix.batch_size);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
@@ -61,13 +67,30 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
: se::blas::Transpose::kNoTranspose;
auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
+ if (batch_size == 1) {
+ return stream
+ ->ThenBlasGemm(
+ lhs_transpose, rhs_transpose, output_matrix.num_rows,
+ output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
+ lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
+ /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
+ &output_data, /*leading dim of output=*/output_matrix.num_rows)
+ .ok();
+ }
+
+ int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols;
+ int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols;
+ int64 output_stride = output_matrix.num_rows * output_matrix.num_cols;
return stream
- ->ThenBlasGemm(
+ ->ThenBlasGemmStridedBatched(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
- output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
- lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
- /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
- &output_data, /*leading dim of output=*/output_matrix.num_rows)
+ output_matrix.num_cols, /*size of reduce dim=*/k,
+ /*alpha=*/alpha, lhs_data,
+ /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data,
+ /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride,
+ /*beta=*/0.0, &output_data,
+ /*leading dim of output=*/output_matrix.num_rows, output_stride,
+ batch_size)
.ok();
}
@@ -94,6 +117,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
se::blas::ProfileResult* output_profile_result) {
DCHECK(!output_matrix.transpose);
+ CHECK_EQ(1, lhs_matrix.batch_size);
+ CHECK_EQ(1, rhs_matrix.batch_size);
+ CHECK_EQ(1, output_matrix.batch_size);
+
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
@@ -174,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm<float>) {
return &DoGemm<float>;
case F64:
return &DoGemm<double>;
+ case C64:
+ return &DoGemm<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@@ -187,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type)
return &DoGemmWithAlgorithm<float>;
case F64:
return &DoGemmWithAlgorithm<double>;
+ case C64:
+ return &DoGemmWithAlgorithm<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@@ -199,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune<float>) {
return &DoGemmAutotune<float>;
case F64:
return &DoGemmAutotune<double>;
+ case C64:
+ return &DoGemmAutotune<std::complex<float>>;
default:
LOG(FATAL) << "Unsupported type.";
}
@@ -217,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
return se::blas::ComputationType::kF32;
case F64:
return se::blas::ComputationType::kF64;
+ case C64:
+ return se::blas::ComputationType::kComplexF32;
default:
LOG(FATAL) << "Unsupported type.";
}
@@ -270,12 +305,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::DeviceMemoryBase output_data =
buffer_allocations.GetDeviceAddress(output_buffer_);
+ DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
+ dim_nums.rhs_batch_dimensions_size());
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
+ ShapeUtil::Rank(output_shape_));
+
+ int64 row_dim = dim_nums.lhs_batch_dimensions_size();
+ int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
+ int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
+ output_shape_.dimensions().end() - 2, 1,
+ std::multiplies<int64>());
+
+ // Check that the batch dims don't cover the last two dims.
+ for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
+ CHECK_NE(row_dim, batch_dim);
+ CHECK_NE(col_dim, batch_dim);
+ }
+
+ // Verify that the non-batch dimensions are minor-most. This is required for
+ // efficient access.
+ for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
+ CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
+ CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
+ }
+
// BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
// matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
// their layout. Therefore, we should treat dimension 0 as row and dimension 1
// as column when mapping a matrix Dot to BLAS gemm.
- int64 output_num_rows = output_shape_.dimensions(0);
- int64 output_num_cols = output_shape_.dimensions(1);
+ int64 output_num_rows = output_shape_.dimensions(row_dim);
+ int64 output_num_cols = output_shape_.dimensions(col_dim);
// BLAS gemm expects the inputs and the output are in column-major order.
// Therefore, we need to convert dot between row-major matrices to that
@@ -298,31 +358,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
// the leading dimension of the LHS matrix of gemm is the number of rows in
// B^T and thus the number of columns in B.
- auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape,
- bool transpose) -> MatrixDescriptor {
- bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0;
- bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) !=
- LayoutUtil::Minor(output_shape_.layout(), 0);
- return MatrixDescriptor(data, transpose ^ layout_mismatch,
- shape.dimensions(is_row_major),
- shape.dimensions(!is_row_major));
+ auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
+ bool transpose) -> MatrixDescriptor {
+ bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
+ bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
+ LayoutUtil::Minor(output_shape_.layout(), row_dim);
+ return MatrixDescriptor(
+ data, transpose ^ layout_mismatch,
+ shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
+ shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
+ batch_size);
};
- DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
-
const MatrixDescriptor lhs_descriptor = make_descriptor(
- lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
+ lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
const MatrixDescriptor rhs_descriptor = make_descriptor(
- rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1);
+ rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
// autotune this gemm to figure out the best algorithm.
- auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
- MatrixDescriptor output_matrix, se::Stream* stream) {
+ auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
+ MatrixDescriptor output_matrix, se::Stream* stream) {
PrimitiveType element_type = output_shape_.element_type();
se::blas::ComputationType computation_type =
GetBlasComputationType(element_type);
+    // TODO(b/112111608): Implement autotuning for batched gemm.
+ if (batch_size != 1) {
+ return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
+ alpha_, stream);
+ }
+
auto thunk_name = [&] {
return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
: "<null>";
@@ -368,16 +434,16 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
bool launch_ok;
- if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) {
- launch_ok = launch(
- lhs_descriptor, rhs_descriptor,
- MatrixDescriptor(output_data, false, output_num_rows, output_num_cols),
- stream);
+ if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
+ launch_ok = launch(lhs_descriptor, rhs_descriptor,
+ MatrixDescriptor(output_data, false, output_num_rows,
+ output_num_cols, batch_size),
+ stream);
} else {
- launch_ok = launch(
- rhs_descriptor, lhs_descriptor,
- MatrixDescriptor(output_data, false, output_num_cols, output_num_rows),
- stream);
+ launch_ok = launch(rhs_descriptor, lhs_descriptor,
+ MatrixDescriptor(output_data, false, output_num_cols,
+ output_num_rows, batch_size),
+ stream);
}
if (!launch_ok) {
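The batched path above computes the batch count as the product of every output dimension except the two matrix dimensions, and passes per-operand strides of rows * cols to ThenBlasGemmStridedBatched. A small self-contained sketch of that bookkeeping (the helper and struct names are hypothetical, not part of the thunk):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

struct BatchedGemmDims {
  int64_t batch_size;
  int64_t lhs_stride;
  int64_t rhs_stride;
  int64_t output_stride;
};

// Assumes output_dims.size() >= 2: the last two dimensions are the matrix
// rows and columns, everything in front of them is a batch dimension.
BatchedGemmDims ComputeBatchedGemmDims(const std::vector<int64_t>& output_dims,
                                       int64_t lhs_rows, int64_t lhs_cols,
                                       int64_t rhs_rows, int64_t rhs_cols) {
  BatchedGemmDims d;
  d.batch_size = std::accumulate(output_dims.begin(), output_dims.end() - 2,
                                 int64_t{1}, std::multiplies<int64_t>());
  d.lhs_stride = lhs_rows * lhs_cols;
  d.rhs_stride = rhs_rows * rhs_cols;
  d.output_stride = output_dims[output_dims.size() - 2] *
                    output_dims[output_dims.size() - 1];
  return d;
}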
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index bb71c79fd7..bb7736efa6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -293,7 +293,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
// the respective location in ShapedBuffer.
std::set<se::DeviceMemoryBase> buffers_in_result;
TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus(
- [&buffer_allocations, &buffers_in_result, &shaped_buffer, this](
+ [&buffer_allocations, &buffers_in_result, this](
const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
const auto& sources = this->GetRootPointsToSet().element(index);
// The points-to set is unambiguous so the set should be a
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 6ac5dfbcd5..d033faee8d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -176,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints(
TF_RETURN_IF_ERROR(
AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
}
+
+ // For batched dot we require the default layout.
+ // TODO(b/112111608): This is overly conservative, the only real restriction
+ // is that batch dimensions must be major.
+ if (instruction->opcode() == HloOpcode::kDot &&
+ ImplementedAsGemm(*instruction) &&
+ instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
+ // Verify that the batch dims come before the row and col dims.
+ const DotDimensionNumbers& dim_nums =
+ instruction->dot_dimension_numbers();
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
+ dim_nums.rhs_batch_dimensions_size());
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
+ ShapeUtil::Rank(instruction->shape()));
+ for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
+ CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2);
+ }
+
+ // Set both inputs and the output to default layout.
+ Shape op0_shape = instruction->operand(0)->shape();
+ LayoutUtil::SetToDefaultLayout(&op0_shape);
+ Shape op1_shape = instruction->operand(1)->shape();
+ LayoutUtil::SetToDefaultLayout(&op1_shape);
+ Shape output_shape = instruction->shape();
+ LayoutUtil::SetToDefaultLayout(&output_shape);
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(op0_shape, instruction, 0));
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(op1_shape, instruction, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(output_shape, instruction));
+ }
}
return Status::OK();
}
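SetToDefaultLayout assigns the descending minor-to-major order {rank-1, ..., 1, 0}, which makes dimension 0 the most major; for a batched dot of rank B + 2 that places all B batch dimensions major, the one property the TODO above notes as the real requirement. A tiny sketch of that default order (plain C++, not the LayoutUtil API):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Illustrative helper: the default minor-to-major permutation for a shape of
// the given rank, i.e. {rank-1, ..., 1, 0}; dimension 0 ends up most major.
std::vector<int64_t> DefaultMinorToMajor(int64_t rank) {
  std::vector<int64_t> minor_to_major(rank);
  std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
  return minor_to_major;
}

int main() {
  for (int64_t d : DefaultMinorToMajor(4)) std::cout << d << " ";  // 3 2 1 0
  std::cout << "\n";
  return 0;
}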
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 95f78ae293..286547ebae 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -31,6 +33,8 @@ namespace xla {
namespace gpu {
namespace {
+namespace op = xla::testing::opcode_matchers;
+
using LayoutAssignmentTest = HloTestBase;
TEST_F(LayoutAssignmentTest, Elementwise) {
@@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
}
}
+TEST_F(LayoutAssignmentTest, DotLayout) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[8,8,256,64]{3,1,2,0} parameter(0)
+ p1 = f32[8,8,256,64]{3,1,2,0} parameter(1)
+ ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1),
+ lhs_batch_dims={0,1}, lhs_contracting_dims={3},
+ rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ GpuLayoutAssignment layout_assignment(&computation_layout,
+ backend().default_stream_executor());
+ EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
+
+ Shape expected_shape =
+ ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Dot(op::ShapeWithLayout(expected_shape),
+ op::ShapeWithLayout(expected_shape)));
+}
+
} // namespace
} // namespace gpu
} // namespace xla
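The test gives the parameters the non-default layout {3,1,2,0} and expects layout assignment to rewrite both operands to the default {3,2,1,0}. Minor-to-major lists the dimensions from fastest- to slowest-varying, so each dimension's stride is a running product; a standalone sketch of that computation, not part of the patch:

#include <cstdint>
#include <iostream>
#include <vector>

// Computes per-dimension strides (in elements) from a dims array and a
// minor-to-major permutation, where minor_to_major[0] is the fastest-varying
// dimension.
std::vector<int64_t> StridesFor(const std::vector<int64_t>& dims,
                                const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> strides(dims.size());
  int64_t running = 1;
  for (int64_t dim : minor_to_major) {
    strides[dim] = running;
    running *= dims[dim];
  }
  return strides;
}

int main() {
  std::vector<int64_t> dims = {8, 8, 256, 64};
  // Default layout {3,2,1,0}: dim 3 is fastest varying.
  auto def = StridesFor(dims, {3, 2, 1, 0});    // {131072, 16384, 64, 1}
  // The parameters' original layout {3,1,2,0} interleaves dims 1 and 2.
  auto other = StridesFor(dims, {3, 1, 2, 0});  // {131072, 64, 512, 1}
  std::cout << def[1] << " vs " << other[1] << "\n";  // 16384 vs 64
  return 0;
}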
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 6352b330d1..c349063c71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -38,24 +38,27 @@ namespace gpu {
namespace {
// Return whether the given shape is a matrix with no padding.
-bool IsRank2WithNoPadding(const Shape& shape) {
- return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
+bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) {
+ return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 &&
+ !LayoutUtil::IsPadded(shape);
}
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
- const Shape& output_shape) {
+ const Shape& output_shape,
+ int64 batch_dimensions_size) {
// The inputs and the output must
// 1) be matrices with no padding and a non-zero number of elements,
// 2) have an allowed element type.
PrimitiveType output_primitive_type = output_shape.element_type();
bool type_is_allowed =
(output_primitive_type == F16 || output_primitive_type == F32 ||
- output_primitive_type == F64);
- return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
- IsRank2WithNoPadding(rhs_shape) &&
- IsRank2WithNoPadding(output_shape) &&
+ output_primitive_type == F64 || output_primitive_type == C64);
+ return type_is_allowed &&
+ IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) &&
+ IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) &&
+ IsRank2WithNoPadding(output_shape, batch_dimensions_size) &&
!ShapeUtil::IsZeroElementArray(lhs_shape) &&
!ShapeUtil::IsZeroElementArray(rhs_shape);
}
@@ -64,14 +67,15 @@ bool DotImplementedAsGemm(const HloInstruction& dot) {
CHECK_EQ(dot.opcode(), HloOpcode::kDot);
const Shape& lhs_shape = dot.operand(0)->shape();
const Shape& rhs_shape = dot.operand(1)->shape();
+ const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
+ dim_numbers.lhs_batch_dimensions_size())) {
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
// errors.
- const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
return true;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 1295e83c0c..541cacf697 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -125,6 +125,10 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) {
return Unimplemented("Recv-done is not implemented on GPU");
}
+Status IrEmitter::HandleScatter(HloInstruction*) {
+ return Unimplemented("Scatter is not implemented on GPUs.");
+}
+
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
std::vector<llvm::Value*> base_ptrs;
for (const HloInstruction* operand : tuple->operands()) {
@@ -450,6 +454,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const Shape& lhs_shape = lhs_instruction->shape();
const Shape& rhs_shape = rhs_instruction->shape();
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+ CHECK_EQ(dnums.lhs_batch_dimensions_size(),
+ dnums.rhs_batch_dimensions_size());
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
llvm::Type* index_type = b_.getInt64Ty();
@@ -485,9 +492,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const int64 lhs_reduction_dimension =
ShapeUtil::GetDimensionNumber(lhs_shape, -1);
const int64 rhs_reduction_dimension =
- ShapeUtil::Rank(rhs_shape) >= 2
+ ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size()
? ShapeUtil::GetDimensionNumber(rhs_shape, -2)
- : 0;
+ : dnums.lhs_batch_dimensions_size();
+
+ // Check that the batch dims don't cover the last two dims.
+ for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
+ CHECK_NE(lhs_reduction_dimension, batch_dim);
+ CHECK_NE(rhs_reduction_dimension, batch_dim);
+ }
// Verify the reduction dimension in the two operands are the same size.
TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
@@ -502,6 +515,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest(
rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
+  // We don't have to iterate over the batch dimensions in both arrays;
+  // simplify the loop nest of the rhs by reusing the lhs batch indices.
+ for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
+ DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i));
+ rhs_index[i] = lhs_index[i];
+ }
+
// Create the reduction loop which does the sum of products reduction.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
/*start_index=*/0,
@@ -564,7 +584,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
target_index.push_back(lhs_index[dimension]);
}
}
- for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) {
+  // Skip over the batch dimensions so they do not appear in the index twice.
+ for (size_t dimension = dnums.lhs_batch_dimensions_size();
+ dimension < rhs_index.size(); ++dimension) {
if (dimension != rhs_reduction_dimension) {
target_index.push_back(rhs_index[dimension]);
}
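With the batch indices of the rhs tied to those of the lhs and the batch dimensions skipped when assembling the output index, the emitted loop nest for a rank-3 batched dot behaves like the plain C++ sketch below (illustrative only, flat row-major buffers assumed):

#include <cstdint>
#include <iostream>
#include <vector>

// out[b, i, j] = sum_k lhs[b, i, k] * rhs[b, k, j]
// The batch index b is shared by lhs and rhs and appears once in the output.
std::vector<float> BatchedDot(const std::vector<float>& lhs,
                              const std::vector<float>& rhs, int64_t batch,
                              int64_t m, int64_t k, int64_t n) {
  std::vector<float> out(batch * m * n, 0.0f);
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t i = 0; i < m; ++i) {
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int64_t kk = 0; kk < k; ++kk) {
          acc += lhs[(b * m + i) * k + kk] * rhs[(b * k + kk) * n + j];
        }
        out[(b * m + i) * n + j] = acc;
      }
    }
  }
  return out;
}

int main() {
  std::vector<float> lhs(2 * 2 * 3, 1.0f), rhs(2 * 3 * 2, 1.0f);
  auto out = BatchedDot(lhs, rhs, /*batch=*/2, /*m=*/2, /*k=*/3, /*n=*/2);
  std::cout << out[0] << "\n";  // 3, the size of the contracted dimension
  return 0;
}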
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 80e2a203ac..561c683879 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -86,6 +86,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleFusion(HloInstruction* fusion) override;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 874c7cfb8a..d5ecae88ed 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -171,40 +171,6 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
return DfsHloVisitor::Postprocess(hlo);
}
-namespace {
-bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment,
- const HloInstruction& hlo) {
- // `hlo` needs to satisfy the following conditions to be implemented as a
- // host-to-device cuMemcpy.
- //
- // 1. `hlo` is a kCopy instruction.
- // 2. `hlo`'s only operand is a kConstant instruction.
- // 3. `hlo` and its operand have the same shape (thus the same layout too).
- // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing
- // pointers in a tuple).
- return hlo.opcode() == HloOpcode::kCopy &&
- hlo.operand(0)->opcode() == HloOpcode::kConstant &&
- ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
- buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok();
-}
-
-bool ImplementedAsDeviceToDeviceMemcpy(
- const BufferAssignment& buffer_assignment, const HloInstruction& hlo) {
- // `hlo` needs to satisfy three conditions to be implemented as a
- // device-to-device cuMemcpy.
- //
- // 1. `hlo` is a kCopy instruction.
- // 2. `hlo` and its operand have the same shape (thus the same layout too).
- // 3. `hlo` and its operand have a statically-known buffer assignment
- // (constants do not, for instance), which means the source buffer also
- // resides on the device.
- return hlo.opcode() == HloOpcode::kCopy &&
- ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) &&
- buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() &&
- buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok();
-}
-} // namespace
-
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
const HloInstruction& inst,
tensorflow::gtl::ArraySlice<const BufferAllocation*> args) {
@@ -379,11 +345,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
}
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
- const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
- if (dnums.lhs_batch_dimensions_size() > 0 ||
- dnums.rhs_batch_dimensions_size() > 0) {
- return Unimplemented("Dot with batch dimensions not implemented.");
- }
if (ImplementedAsGemm(*dot)) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
@@ -730,13 +691,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
}
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
- if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
- *copy)) {
- thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy));
- return Status::OK();
- }
- if (ImplementedAsDeviceToDeviceMemcpy(
- ir_emitter_context_->buffer_assignment(), *copy)) {
+ CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
+ const BufferAssignment& buffer_assignment =
+ ir_emitter_context_->buffer_assignment();
+ if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
+ copy->shape().layout()) &&
+ buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
return Status::OK();
}
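After this change a kCopy is lowered to a device-to-device memcpy exactly when the source and destination layouts are equal and the source buffer has a statically known slice; a copy that changes the layout is a physical transpose and still goes through a kernel, since the same logical element sits at a different linear offset in each layout. A tiny standalone illustration of why a flat memcpy cannot change layout:

#include <cstdint>
#include <iostream>

// Linear offset of element (i, j) of a rows x cols array.
int64_t RowMajorOffset(int64_t i, int64_t j, int64_t rows, int64_t cols) {
  return i * cols + j;
}
int64_t ColMajorOffset(int64_t i, int64_t j, int64_t rows, int64_t cols) {
  return j * rows + i;
}

int main() {
  // Element (1, 2) of a 3x4 array sits at different offsets in the two
  // layouts, so copying between them must permute elements, not bytes.
  std::cout << RowMajorOffset(1, 2, 3, 4) << " vs "
            << ColMajorOffset(1, 2, 3, 4) << "\n";  // 6 vs 7
  return 0;
}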
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index 6c1c20fc04..cf44458a2e 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -114,21 +114,20 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
// Gets the GPU name as it's known to LLVM for a given compute capability. If
// we see an unrecognized compute capability, we return "sm_30".
static string GetSmName(std::pair<int, int> compute_capability) {
- static auto* m = new std::map<std::pair<int, int>, int>(
- {{{2, 0}, 20},
- {{2, 1}, 21},
- {{3, 0}, 30},
- {{3, 2}, 32},
- {{3, 5}, 35},
- {{3, 7}, 37},
- {{5, 0}, 50},
- {{5, 2}, 52},
- {{5, 3}, 53},
- {{6, 0}, 60},
- {{6, 1}, 61},
- {{6, 2}, 62},
- // TODO: Change this to 70 once LLVM NVPTX supports it
- {{7, 0}, 60}});
+ static auto* m = new std::map<std::pair<int, int>, int>({
+ {{3, 0}, 30},
+ {{3, 2}, 32},
+ {{3, 5}, 35},
+ {{3, 7}, 37},
+ {{5, 0}, 50},
+ {{5, 2}, 52},
+ {{5, 3}, 53},
+ {{6, 0}, 60},
+ {{6, 1}, 61},
+ {{6, 2}, 62},
+ {{7, 0}, 70},
+ {{7, 2}, 72},
+ });
int sm_version = 30;
auto it = m->find(compute_capability);
if (it != m->end()) {
@@ -329,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
if (linker.linkInModule(
std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded,
[](Module& M, const StringSet<>& GVS) {
- internalizeModule(M, [&M, &GVS](const GlobalValue& GV) {
+ internalizeModule(M, [&GVS](const GlobalValue& GV) {
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
});
})) {
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c67dcbce77..c62bae0628 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -115,15 +115,23 @@ bool IsInputFusibleReduction(HloInstruction* instr) {
// will be broadcasted and have not been observed to cause data locality issues.
// TODO(b/111977086): Improve reduce emitters to remove this limitation.
bool ReduceFriendlyInputLayouts(HloInstruction* instr) {
+ std::vector<HloInstruction*> params;
+ if (instr->opcode() == HloOpcode::kFusion) {
+ params = instr->fused_parameters();
+ } else {
+ for (HloInstruction* operand : instr->operands()) {
+ params.push_back(operand);
+ }
+ }
int64 max_rank = 0;
const Layout* max_rank_layout;
- for (HloInstruction* param : instr->fused_parameters()) {
+ for (HloInstruction* param : params) {
if (ShapeUtil::Rank(param->shape()) > max_rank) {
max_rank = ShapeUtil::Rank(param->shape());
max_rank_layout = &param->shape().layout();
}
}
- return c_all_of(instr->fused_parameters(), [&](HloInstruction* param) {
+ return c_all_of(params, [&](HloInstruction* param) {
return (ShapeUtil::Rank(param->shape()) < max_rank) ||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
});
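The check now collects either the fused parameters of a fusion or the direct operands of any other instruction, and accepts them when every input of maximal rank shares the layout of the largest one (lower-rank inputs get broadcast and are ignored). A condensed sketch of that predicate over plain (rank, layout-id) pairs, purely illustrative:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct ParamInfo {
  int64_t rank;
  int64_t layout_id;  // stand-in for a concrete layout
};

// True if all parameters of maximal rank agree on one layout.
bool ReduceFriendly(const std::vector<ParamInfo>& params) {
  int64_t max_rank = 0;
  int64_t max_rank_layout = -1;
  for (const ParamInfo& p : params) {
    if (p.rank > max_rank) {
      max_rank = p.rank;
      max_rank_layout = p.layout_id;
    }
  }
  return std::all_of(params.begin(), params.end(), [&](const ParamInfo& p) {
    return p.rank < max_rank || p.layout_id == max_rank_layout;
  });
}

int main() {
  std::cout << ReduceFriendly({{3, 7}, {3, 7}, {1, 5}}) << "\n";  // 1
  return 0;
}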
@@ -221,7 +229,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
const bool is_loop_fusion =
producer->opcode() == HloOpcode::kFusion &&
producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
- if (!is_loop_fusion) {
+ if (!producer->IsElementwise() && !is_loop_fusion) {
VLOG(3) << producer->name() << " is not a loop fusion.";
continue;
}
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index ec4234b8d9..14f157a5e5 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -256,6 +256,26 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
+TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ exp = f32[2,2,2]{2,1,0} exponential(p0)
+ reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Exp()));
+}
+
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
fused_add {
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 7a683ede54..8fa0439006 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
-#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@@ -148,7 +147,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// support BF16 operations without directly implementing a BF16 lowering for
// most ops.
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
- pipeline.AddPass<DotDecomposer>();
{
auto& pass =
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 63a8a813cd..0b93d97c11 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -160,6 +160,8 @@ message HloInstructionProto {
// present for Send and Recv instructions and their SendDone and RecvDone
// partners.
bool is_host_transfer = 47;
+
+ xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 1f672502f7..a2cefd2621 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -49,9 +49,9 @@ Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
// The default number of bytes accessed for an instruction is the sum of the
// sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
// handle opaque types.
- float bytes_accessed = shape_size_(hlo->shape());
+ float bytes_accessed = GetShapeSize(hlo->shape());
for (const HloInstruction* operand : hlo->operands()) {
- bytes_accessed += shape_size_(operand->shape());
+ bytes_accessed += GetShapeSize(operand->shape());
}
current_properties_[kBytesAccessedKey] = bytes_accessed;
@@ -121,6 +121,13 @@ Status HloCostAnalysis::HandleElementwiseOp(
}
}
+int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return 0;
+ }
+ return shape_size_(shape);
+}
+
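GetShapeSize wraps the user-supplied size callback so that shapes without a layout (common inside fusion computations) count as zero bytes instead of tripping the callback, which is what lets HandleFusion below drop its dummy shape_size function. A minimal sketch of the same decoration over assumed stand-in types, not the real XLA classes:

#include <cstdint>
#include <functional>
#include <iostream>
#include <optional>

// Assumed, illustrative stand-ins for the real XLA types.
struct Layout {};
struct Shape {
  int64_t byte_size = 0;
  std::optional<Layout> layout;  // fused shapes may carry no layout
};

using ShapeSizeFn = std::function<int64_t(const Shape&)>;

// Decorates `size_fn`: shapes without a layout are counted as 0 bytes.
int64_t SafeShapeSize(const ShapeSizeFn& size_fn, const Shape& shape) {
  if (!shape.layout.has_value()) {
    return 0;
  }
  return size_fn(shape);
}

int main() {
  ShapeSizeFn raw = [](const Shape& s) { return s.byte_size; };
  Shape with_layout{64, Layout{}};
  Shape layoutless{64, std::nullopt};
  std::cout << SafeShapeSize(raw, with_layout) << " "
            << SafeShapeSize(raw, layoutless) << "\n";  // 64 0
  return 0;
}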
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
return HandleElementwiseOp(hlo);
}
@@ -181,21 +188,21 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
}
Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
- current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2;
+ current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
return Status::OK();
}
Status HloCostAnalysis::HandleDynamicSlice(
const HloInstruction* dynamic_slice) {
current_properties_[kBytesAccessedKey] =
- shape_size_(dynamic_slice->shape()) * 2;
+ GetShapeSize(dynamic_slice->shape()) * 2;
return Status::OK();
}
Status HloCostAnalysis::HandleDynamicUpdateSlice(
const HloInstruction* dynamic_update_slice) {
current_properties_[kBytesAccessedKey] =
- shape_size_(dynamic_update_slice->operand(1)->shape()) * 2;
+ GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2;
return Status::OK();
}
@@ -204,7 +211,7 @@ Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
// through them). The memory touched is then only the size of the output
// index table of the tuple.
- current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape());
+ current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
return Status::OK();
}
@@ -526,12 +533,12 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
// TODO(b/33004697): Compute correct cost here, taking the actual number of
// replicas into account.
double flops = 0.0;
- ShapeUtil::ForEachSubshape(
- crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) {
- if (ShapeUtil::IsArray(subshape)) {
- flops += ShapeUtil::ElementsIn(subshape);
- }
- });
+ ShapeUtil::ForEachSubshape(crs->shape(),
+ [&](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsArray(subshape)) {
+ flops += ShapeUtil::ElementsIn(subshape);
+ }
+ });
current_properties_[kFlopsKey] = flops;
return Status::OK();
}
@@ -546,15 +553,9 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
}
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
- // Compute the properties of the fused expression and attribute them to the
- // fusion node. Use a dummy shape_size to avoid any errors from trying to
- // calculate the size of a shape that does not have a layout, since nodes
- // inside fusion nodes do not necessarily have a layout assigned.
- ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; };
TF_ASSIGN_OR_RETURN(
current_properties_,
- ProcessSubcomputation(fusion->fused_instructions_computation(),
- &shape_size));
+ ProcessSubcomputation(fusion->fused_instructions_computation()));
// Fusion nodes that produce a tuple also produce the entries in the tuple.
// Ignore the memory accessed inside fused ops, since fusion is supposed to
@@ -563,11 +564,11 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
ShapeUtil::ForEachSubshape(
fusion->shape(),
[this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
- current_properties_[kBytesAccessedKey] += shape_size_(subshape);
+ current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
});
for (const HloInstruction* operand : fusion->operands()) {
- current_properties_[kBytesAccessedKey] += shape_size_(operand->shape());
+ current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape());
}
return Status::OK();
@@ -648,6 +649,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
return Status::OK();
}
+Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
+ // TODO(b/32945756): Compute the properties of the sub-computation.
+ return Status::OK();
+}
+
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
return Status::OK();
}
@@ -685,11 +691,8 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
}
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
- HloComputation* computation, const ShapeSizeFunction* shape_size) {
- if (shape_size == nullptr) {
- shape_size = &shape_size_;
- }
- HloCostAnalysis visitor(*shape_size, per_second_rates_);
+ HloComputation* computation) {
+ HloCostAnalysis visitor(shape_size_, per_second_rates_);
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.properties();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 82d650dc7b..0a79c92f4a 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -104,6 +104,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
Status HandleGather(const HloInstruction* gather) override;
+ Status HandleScatter(const HloInstruction* scatter) override;
Status FinishVisit(const HloInstruction* root) override;
Status Preprocess(const HloInstruction* hlo) override;
@@ -149,11 +150,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
const Properties& per_second_rates);
// Returns the properties computed from visiting the computation rooted at the
- // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null,
- // otherwise uses shape_size_.
- StatusOr<Properties> ProcessSubcomputation(
- HloComputation* computation,
- const ShapeSizeFunction* shape_size = nullptr);
+ // given hlo.
+ StatusOr<Properties> ProcessSubcomputation(HloComputation* computation);
// Utility function to handle all element-wise operations.
Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
@@ -170,6 +168,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
static float GetPropertyForHlo(const HloInstruction& hlo, const string& key,
const HloToProperties& hlo_to_properties);
+ // Decorates shape_size_ by returning 0 immediately if the shape does not have
+ // a layout.
+ int64 GetShapeSize(const Shape& shape) const;
+
// Function which computes the size of the top-level of a given shape (not
// including nested elements, if any). If null then bytes_accessed methods
// return an error.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 2ec31a9148..4755c4a0cf 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2365,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
- auto make_cond = [this, &data_shape]() {
+ auto make_cond = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -2374,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
- auto make_body = [this, &data_shape]() {
+ auto make_body = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Body");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index d5b4be7e12..d1ee4a180b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1481,8 +1481,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
ShapeUtil::Rank(arg->shape()) - dimensions.size());
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
- /*arg=*/arg->shape(),
- /*init_value=*/init_value->shape(),
+ {&arg->shape(), &init_value->shape()},
/*dimensions_to_reduce=*/dimensions,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index fd5085bed2..7e5866a356 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1019,6 +1019,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kWhite;
}
return kGreen;
+ case HloOpcode::kScatter:
+ // Do not de-emphasize Scatter, since it involves significant work.
case HloOpcode::kCopy:
// Emphasize copy nodes, which are either physical transposes (and thus
// significant), or copies of read-only buffers (and thus dead weight).
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8b9bdd2f46..7591b99204 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -404,6 +404,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
*gather_dimension_numbers, gather_window_bounds);
break;
}
+ case HloOpcode::kScatter: {
+ TF_RET_CHECK(proto.operand_ids_size() == 3)
+ << "Scatter instruction should have 3 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_scatter_dimension_numbers())
+ << "Scatter instruction should have ScatterDimensionNumbers set.";
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "Scatter instruction should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
+ auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>(
+ proto.scatter_dimension_numbers());
+ instruction =
+ CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
+ computations(0), *scatter_dimension_numbers);
+ break;
+ }
default: {
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -812,11 +828,25 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
- const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
- return MakeUnique<HloReduceInstruction>(
- shape, arg, init_value, dimensions_to_reduce, reduce_computation);
+ auto instruction = WrapUnique(new HloReduceInstruction(
+ shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
+ return std::move(instruction);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation) {
+ std::vector<HloInstruction*> all_args;
+ all_args.reserve(operands.size() * 2);
+ all_args.insert(all_args.end(), operands.begin(), operands.end());
+ all_args.insert(all_args.end(), init_values.begin(), init_values.end());
+ return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce,
+ reduce_computation);
}
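The variadic overload flattens its two argument lists into one operand vector, all reduced operands first and then all initial values, the order the parser change later in this patch assumes when it splits a parsed operand list back into halves. A short sketch of that packing convention over plain strings (illustrative, not the XLA API):

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Packs [op0..opN-1, init0..initN-1] the way the reduce constructor expects.
std::vector<std::string> PackReduceArgs(
    const std::vector<std::string>& operands,
    const std::vector<std::string>& inits) {
  assert(operands.size() == inits.size());
  std::vector<std::string> all;
  all.reserve(operands.size() * 2);
  all.insert(all.end(), operands.begin(), operands.end());
  all.insert(all.end(), inits.begin(), inits.end());
  return all;
}

int main() {
  auto all =
      PackReduceArgs({"values", "indices"}, {"init_value", "init_index"});
  // The first half are the reduced operands, the second half their inits.
  size_t n = all.size() / 2;
  std::cout << all[0] << " ... " << all[n] << "\n";  // values ... init_value
  return 0;
}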
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
@@ -1062,6 +1092,16 @@ bool HloInstruction::HasSideEffect() const {
gather_dim_numbers, window_bounds);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers) {
+ return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices,
+ updates, update_computation,
+ scatter_dim_numbers);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
@@ -1124,6 +1164,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
case HloOpcode::kGather:
+ case HloOpcode::kScatter:
case HloOpcode::kIota:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
@@ -1587,6 +1628,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
+ case HloOpcode::kScatter:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -1693,6 +1735,7 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kScatter:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@@ -1711,6 +1754,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kScatter:
CHECK_EQ(called_computations_.size(), 1);
called_computations_[0] = computation;
break;
@@ -1977,7 +2021,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
opcode() == HloOpcode::kReduce ||
- opcode() == HloOpcode::kCrossReplicaSum) {
+ opcode() == HloOpcode::kCrossReplicaSum ||
+ opcode() == HloOpcode::kScatter) {
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
@@ -2013,6 +2058,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kScatter:
extra.push_back(
StrCat("to_apply=\n", to_apply()->ToString(new_options)));
break;
@@ -2311,6 +2357,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSendDone(this);
case HloOpcode::kGather:
return visitor->HandleGather(this);
+ case HloOpcode::kScatter:
+ return visitor->HandleScatter(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
case HloOpcode::kAfterAll:
@@ -3171,4 +3219,9 @@ tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
return Cast<HloGatherInstruction>(this)->gather_window_bounds();
}
+const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
+ const {
+ return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 70441b879d..e722086732 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -541,17 +541,34 @@ class HloInstruction {
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
- // is applied successively to every element in operand. That is, if f is the
- // function to apply (which either takes 2 [accumulator, value] or 3
- // [accumulator, index, value] arguments) and init is a reduction operator
- // specified initial value (for example, 0 for addition), then this operation
- // will compute:
- // f(f(init, [index0], value0), [index1], value1), ...)
+ // is applied successively to every element in operand. For example, let f be
+ // the function to apply, which takes 2 arguments, an accumulator and the
+ // current value. Let init be an initial value (which is normally chosen to be
+ // the identity element for f, e.g. 0 if f is addition).
+ // Then the reduce HLO will compute:
+  //   f(...f(f(init, value0), value1), ..., valueN)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
+ // A more general, multiple-argument version of the above.
+  // The function to apply, f, now takes 2N arguments:
+  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
+  // valueN], and returns an N-tuple. The performed computation is (for
+ // commutative and associative f operators) equivalent to:
+ //
+ // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0)
+ // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
+ // ..., inputN.value1)
+ // ...
+ // TODO(b/112040122): Add support to this in HLO passes and in backends.
+ static std::unique_ptr<HloInstruction> CreateReduce(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::gtl::ArraySlice<HloInstruction*> init_values,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
+
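For a concrete reading of the comment above, a two-operand reduce with accumulators (sum, count) applies f to the running accumulator pair and the next pair of input values; the plain C++ sketch below illustrates the semantics only, not any backend implementation:

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// f((sum, count), (v, one)) = (sum + v, count + one)
std::pair<double, double> F(std::pair<double, double> acc, double v,
                            double one) {
  return {acc.first + v, acc.second + one};
}

int main() {
  std::vector<double> values = {1.0, 2.0, 3.0};
  std::vector<double> ones = {1.0, 1.0, 1.0};
  std::pair<double, double> acc = {0.0, 0.0};  // (init0, init1)
  for (size_t i = 0; i < values.size(); ++i) {
    acc = F(acc, values[i], ones[i]);  // the f_1, f_2, ... from the comment
  }
  std::cout << acc.first << ", " << acc.second << "\n";  // 6, 3
  return 0;
}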
// Creates a reduce-window instruction, where the computation (given
// by the handle) is applied window-wise at each valid window
// position in the operand.
@@ -644,6 +661,12 @@ class HloInstruction {
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ static std::unique_ptr<HloInstruction> CreateScatter(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers);
+
// Creates a kDomain instruction which delimits an HLO domain which have
// the provided user and operand side metadata.
static std::unique_ptr<HloInstruction> CreateDomain(
@@ -1014,9 +1037,7 @@ class HloInstruction {
if (sharding_ == nullptr) {
return tensorflow::gtl::optional<int64>();
}
- auto device = sharding_->UniqueDevice();
- return device.ok() ? device.ValueOrDie()
- : tensorflow::gtl::optional<int64>();
+ return sharding_->UniqueDevice();
}
// Sets the sharding of this operator. Should only be called by HloModule or
// HloComputation methods.
@@ -1454,6 +1475,9 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_window_bounds.
tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+ // Delegates to HloScatterInstruction::scatter_dimension_numbers().
+ const ScatterDimensionNumbers& scatter_dimension_numbers() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index b75a2bd34b..8a694dde80 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1425,6 +1425,55 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
"index_vector_dim=2, window_bounds={30,29,28,27,26}");
}
+TEST_F(HloInstructionTest, StringifyScatter) {
+ Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ Shape scatter_indices_tensor_shape =
+ ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
+ Shape scatter_updates_shape =
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
+
+ HloComputation::Builder builder("Scatter");
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
+ HloInstruction* scatter_indices =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, scatter_indices_tensor_shape, "scatter_indices"));
+ HloInstruction* scatter_updates =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 2, scatter_updates_shape, "scatter_updates"));
+
+ HloComputation::Builder update_builder("Scatter.update");
+ update_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1"));
+ update_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2"));
+
+ auto module = CreateNewModule();
+ auto* update_computation =
+ module->AddEmbeddedComputation(update_builder.Build());
+
+ HloInstruction* scatter_instruction =
+ builder.AddInstruction(HloInstruction::CreateScatter(
+ input_tensor_shape, input, scatter_indices, scatter_updates,
+ update_computation,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2)));
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(
+ scatter_instruction->ToString(),
+ "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
+ "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
+ "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
+ "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
+ "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
+ "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
+ "to_apply=%Scatter.update");
+}
+
TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
// Tests stringification of a simple op, fusion, while, and conditional.
const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index df26a2c744..1d71a74c40 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -438,13 +438,14 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl(
}
HloReduceInstruction::HloReduceInstruction(
- const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation)
: HloInstruction(HloOpcode::kReduce, shape),
dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
- AppendOperand(arg);
- AppendOperand(init_value);
+ for (HloInstruction* arg : args) {
+ AppendOperand(arg);
+ }
AppendComputation(reduce_computation);
}
@@ -477,8 +478,8 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
- return MakeUnique<HloReduceInstruction>(
- shape, new_operands[0], new_operands[1], dimensions(), to_apply());
+ return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(),
+ to_apply());
}
HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
@@ -2015,4 +2016,91 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
gather_window_bounds());
}
+HloScatterInstruction::HloScatterInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers)
+ : HloInstruction(HloOpcode::kScatter, shape) {
+ AppendOperand(operand);
+ AppendOperand(scatter_indices);
+ AppendOperand(updates);
+ AppendComputation(update_computation);
+ scatter_dimension_numbers_ =
+ MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers);
+}
+
+string HloScatterInstruction::ScatterDimensionNumbersToString() const {
+ string update_window_dims =
+ StrCat("update_window_dims={",
+ Join(scatter_dimension_numbers().update_window_dims(), ","), "}");
+ string inserted_window_dims = StrCat(
+ "inserted_window_dims={",
+ Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
+ string scatter_dims_to_operand_dims = StrCat(
+ "scatter_dims_to_operand_dims={",
+ Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
+ "}");
+ string index_vector_dim = StrCat(
+ "index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
+
+ return Join<std::initializer_list<string>>(
+ {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
+ index_vector_dim},
+ ", ");
+}
+
+/* static */ ScatterDimensionNumbers
+HloScatterInstruction::MakeScatterDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> update_window_dims,
+ tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
+ tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ int64 index_vector_dim) {
+ ScatterDimensionNumbers scatter_dim_numbers;
+ for (int64 update_window_dim : update_window_dims) {
+ scatter_dim_numbers.add_update_window_dims(update_window_dim);
+ }
+ for (int64 inserted_window_dim : inserted_window_dims) {
+ scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
+ }
+ for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
+ scatter_dim_numbers.add_scatter_dims_to_operand_dims(
+ scatter_dim_to_operand_dim);
+ }
+ scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
+ return scatter_dim_numbers;
+}
+
+HloInstructionProto HloScatterInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
+ return proto;
+}
+
+std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {ScatterDimensionNumbersToString()};
+}
+
+bool HloScatterInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
+ return protobuf_util::ProtobufEquals(
+ scatter_dimension_numbers(),
+ casted_other.scatter_dimension_numbers()) &&
+ eq_computations(to_apply(), casted_other.to_apply());
+}
+
+std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 3);
+ return MakeUnique<HloScatterInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
+ scatter_dimension_numbers());
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 132e767420..ac5a1ca080 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -331,7 +331,7 @@ class HloConcatenateInstruction : public HloInstruction {
class HloReduceInstruction : public HloInstruction {
public:
explicit HloReduceInstruction(
- const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> args,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Returns the dimension sizes or numbers associated with this instruction.
@@ -1198,6 +1198,45 @@ class HloGatherInstruction : public HloInstruction {
std::vector<int64> gather_window_bounds_;
};
+class HloScatterInstruction : public HloInstruction {
+ public:
+ explicit HloScatterInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers);
+ const ScatterDimensionNumbers& scatter_dimension_numbers() const {
+ CHECK(scatter_dimension_numbers_ != nullptr);
+ return *scatter_dimension_numbers_;
+ }
+ // Returns the dump string of the scatter dimension numbers.
+ string ScatterDimensionNumbersToString() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Creates an instance of ScatterDimensionNumbers.
+ static ScatterDimensionNumbers MakeScatterDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> update_window_dims,
+ tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
+ tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ int64 index_vector_dim);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index f0d9fdbc8f..71b44507cc 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -299,9 +299,12 @@ TokKind HloLexer::LexNumberOrPattern() {
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
- tensorflow::strings::safe_strto64(
- StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
- return TokKind::kInt;
+ auto slice = StringPieceFromPointers(token_start_, current_ptr_);
+ if (tensorflow::strings::safe_strto64(slice, &int64_val_)) {
+ return TokKind::kInt;
+ }
+ LOG(ERROR) << "Failed to parse int literal: " << slice;
+ return TokKind::kError;
}
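The lexer now reports a token error for an out-of-range integer literal instead of silently keeping whatever safe_strto64 left behind. The same overflow-aware parse can be expressed with std::from_chars, shown here as a self-contained C++17 sketch rather than the TensorFlow helper:

#include <charconv>
#include <cstdint>
#include <iostream>
#include <string_view>
#include <system_error>

// Returns true and sets `out` iff `text` is an in-range int64 literal.
bool ParseInt64(std::string_view text, int64_t* out) {
  auto [ptr, ec] =
      std::from_chars(text.data(), text.data() + text.size(), *out);
  return ec == std::errc() && ptr == text.data() + text.size();
}

int main() {
  int64_t v = 0;
  std::cout << ParseInt64("9223372036854775807", &v) << "\n";  // 1 (fits)
  std::cout << ParseInt64("9223372036854775808", &v) << "\n";  // 0 (overflow)
  return 0;
}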
static LazyRE2 neg_inf = {"-inf"};
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 59e9a5a94a..88531b6f20 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -118,6 +118,7 @@ namespace xla {
V(kReverse, "reverse") \
V(kRng, "rng") \
V(kRoundNearestAfz, "round-nearest-afz") \
+ V(kScatter, "scatter") \
V(kSelect, "select") \
V(kSelectAndScatter, "select-and-scatter") \
V(kSend, "send") \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index d71d3c8170..93cc884e3a 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -865,18 +865,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReduce: {
+ auto loc = lexer_.GetLoc();
+
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
optional<std::vector<tensorflow::int64>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
+ if (operands.size() % 2) {
+ return Error(loc, StrCat("expects an even number of operands, but has ",
+ operands.size(), " operands"));
+ }
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
- shape, /*operand=*/operands[0], /*init_value=*/operands[1],
+ shape, /*operands=*/
+ tensorflow::gtl::ArraySlice<HloInstruction*>(operands, 0,
+ operands.size() / 2),
+ /*init_values=*/
+ tensorflow::gtl::ArraySlice<HloInstruction*>(
+ operands, operands.size() / 2, operands.size()),
*dimensions_to_reduce, *reduce_computation));
break;
}
@@ -1242,6 +1252,42 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
dim_numbers, *window_bounds));
break;
}
+ case HloOpcode::kScatter: {
+ optional<std::vector<tensorflow::int64>> update_window_dims;
+ attrs["update_window_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
+ optional<std::vector<tensorflow::int64>> inserted_window_dims;
+ attrs["inserted_window_dims"] = {
+ /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
+ optional<std::vector<tensorflow::int64>> scatter_dims_to_operand_dims;
+ attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
+ AttrTy::kBracedInt64List,
+ &scatter_dims_to_operand_dims};
+ optional<tensorflow::int64> index_vector_dim;
+ attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
+ &index_vector_dim};
+
+ optional<HloComputation*> update_computation;
+ attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+ &update_computation};
+
+ if (!ParseOperands(&operands, /*expected_size=*/3) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+
+ ScatterDimensionNumbers dim_numbers =
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/*update_window_dims,
+ /*inserted_window_dims=*/*inserted_window_dims,
+ /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
+ /*index_vector_dim=*/*index_vector_dim);
+
+ instruction = builder->AddInstruction(HloInstruction::CreateScatter(
+ shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
+ /*updates=*/operands[2], *update_computation, dim_numbers));
+ break;
+ }
case HloOpcode::kDomain: {
DomainData domain;
attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
@@ -1590,6 +1636,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value,
"value ", value, " is out of range for literal's primitive type ",
PrimitiveType_Name(literal->shape().element_type())));
}
+ } else if (std::is_unsigned<LiteralNativeT>::value) {
+ CHECK((std::is_same<ParsedElemT, tensorflow::int64>::value ||
+ std::is_same<ParsedElemT, bool>::value))
+ << "Unimplemented checking for ParsedElemT";
+
+ ParsedElemT upper_bound;
+ if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) {
+ upper_bound = std::numeric_limits<ParsedElemT>::max();
+ } else {
+ upper_bound =
+ static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max());
+ }
+ if (value > upper_bound || value < 0) {
+ // Value is out of range for LiteralNativeT.
+ return TokenError(StrCat(
+ "value ", value, " is out of range for literal's primitive type ",
+ PrimitiveType_Name(literal->shape().element_type())));
+ }
} else if (value > static_cast<ParsedElemT>(
std::numeric_limits<LiteralNativeT>::max()) ||
value < static_cast<ParsedElemT>(
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 1c08c51220..7344679bb6 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -760,6 +760,46 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5
)"
},
+{
+"scatter",
+R"(HloModule StringifyScatter
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
+ %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
+ %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
+ %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
+ ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
+}
+
+)"
+},
+{
+ "ConstantUnsignedNoUnderflow",
+ R"(HloModule ConstantUnsignedNoUnderflow_module
+
+ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
+ ROOT %constant = u64[] constant(1)
+}
+
+)"
+},
+
+{
+ "ConstantUnsignedNoOverflow",
+ R"(HloModule ConstantUnsignedNoOverflow_module
+
+ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
+ ROOT %constant = u64[] constant(9223372036854775807)
+}
+
+)"
+},
});
// clang-format on
}
@@ -805,6 +845,32 @@ ENTRY ReduceR3ToR2.v3 {
)"
},
+// tuple reduce
+{
+"TupleReduce",
+R"(HloModule TupleReduce
+
+max_argmax {
+ value = f32[] parameter(2)
+ prev_max = f32[] parameter(0)
+ is_next_larger = pred[] greater-than-or-equal-to(value, prev_max)
+ max = f32[] select(is_next_larger, value, prev_max)
+ index = s32[] parameter(3)
+ prev_argmax = s32[] parameter(1)
+ argmax = s32[] select(is_next_larger, index, prev_argmax)
+ ROOT pair = (f32[], s32[]) tuple(max, argmax)
+}
+
+ENTRY reduce_entry {
+ values = f32[1024]{0} parameter(0)
+ indices = f32[1024]{0} parameter(1)
+ init_value = f32[] constant(-inf)
+ init_index = s32[] constant(-1)
+ ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
+}
+
+)"
+},
// infeed/outfeed
{
"InfeedOutfeed",
@@ -1224,6 +1290,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] {
"is out of range for literal's primitive type F16");
}
+TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
+ const string original = R"(
+ HloModule ConstantUnsignedUnderflow_module
+ ENTRY %ConstantUnsignedUnderflow () -> u64[] {
+ ROOT %constant = u64[] constant(-1)
+ })";
+ auto result = ParseHloString(original);
+ EXPECT_NE(Status::OK(), result.status());
+ ExpectHasSubstr(result.status().error_message(),
+ "is out of range for literal's primitive type U64");
+}
+
+TEST_F(HloParserTest, ConstantUnsignedOverflow) {
+ const string original = R"(
+ HloModule ConstantUnsignedOverflow_module
+ ENTRY %ConstantUnsignedOverflow () -> u32[] {
+ ROOT %constant = u32[] constant(4294967296)
+ })";
+ auto result = ParseHloString(original);
+ EXPECT_NE(Status::OK(), result.status());
+ ExpectHasSubstr(result.status().error_message(),
+ "is out of range for literal's primitive type U32");
+}
+
+TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
+ const string original = R"(
+ HloModule ConstantUnsignedOverflow_module
+ ENTRY %ConstantUnsignedOverflow () -> u64[] {
+ ROOT %constant = u64[] constant(9223372036854775808)
+ })";
+ auto result = ParseHloString(original);
+ EXPECT_NE(Status::OK(), result.status());
+}
+
TEST_F(HloParserTest, ConstantWithExp) {
const string original = R"(HloModule ConstantWithExp_module
diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h
index b3d0a07add..28194deb0e 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_fix.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_
+#include <algorithm>
+
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -34,9 +36,19 @@ class HloPassFix : public Pass {
StatusOr<bool> Run(HloModule* module) override {
bool changed = false;
bool changed_this_iteration = true;
+ int64 iteration_count = 0;
+ int64 limit =
+ std::max(static_cast<int64>(1000), module->instruction_count());
while (changed_this_iteration) {
TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
changed |= changed_this_iteration;
+ ++iteration_count;
+ if (iteration_count == limit) {
+ LOG(ERROR)
+            << "Unexpectedly high number of iterations in HLO passes ("
+ << iteration_count
+ << ")\nIf compilation hangs here, please file a bug with XLA.";
+ }
}
return changed;
}
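
HloPassFix reruns the wrapped pass until a run reports no change; with this change it also logs an error once the iteration count hits max(1000, instruction count). A hedged usage sketch, where MyCleanupPass is a hypothetical pass type, not one from this tree:

  // Run a (hypothetical) cleanup pass to a fixed point rather than just once.
  HloPassFix<MyCleanupPass> fixed_point_cleanup;
  TF_ASSIGN_OR_RETURN(bool changed, fixed_point_cleanup.Run(module));
  // `changed` is true if any iteration modified the module.
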
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index cf9ceed5b2..9ec983c2bc 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -282,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
TF_ASSERT_OK_AND_ASSIGN(
SequentialHloOrdering::HloModuleSequence sequence,
ScheduleComputationsInModule(*module,
- [&TUPLE_SIZE](const BufferValue& buffer) {
+ [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(
buffer.shape(), TUPLE_SIZE);
},
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 393944c20f..6399f6ef3c 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -127,15 +127,15 @@ std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
if (IsTuple()) {
for (auto& tuple_element_sharding : tuple_elements()) {
auto unique_device = tuple_element_sharding.UniqueDevice();
- if (unique_device.ok()) {
- device_map[unique_device.ValueOrDie()] += 1;
+ if (unique_device) {
+ device_map[*unique_device] += 1;
}
}
element_count = tuple_elements().size();
} else {
auto unique_device = UniqueDevice();
- if (unique_device.ok()) {
- device_map[unique_device.ValueOrDie()] += 1;
+ if (unique_device) {
+ device_map[*unique_device] += 1;
}
}
if (count != nullptr) {
@@ -238,40 +238,31 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
return Tuple(ShapeTree<HloSharding>(shape, *this));
}
-StatusOr<int64> HloSharding::UniqueDevice() const {
+tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
- return tensorflow::errors::InvalidArgument(
- "UniqueDevice() called on empty tuple");
+ return tensorflow::gtl::nullopt;
}
- std::vector<StatusOr<int64>> results;
- std::transform(tuple_elements_.begin(), tuple_elements_.end(),
- std::back_inserter(results),
- [](const HloSharding& s) { return s.UniqueDevice(); });
- if (std::all_of(results.begin(), results.end(),
- [&](const StatusOr<int64>& s) {
- return s.ok() && results[0].ok() &&
- s.ValueOrDie() == results[0].ValueOrDie();
- })) {
- return results[0];
- } else {
- return tensorflow::errors::InvalidArgument(
- "Tuple did not contain a unique device");
+ tensorflow::gtl::optional<int64> unique_device;
+ for (auto& tuple_sharding : tuple_elements_) {
+ auto device = tuple_sharding.UniqueDevice();
+ if (!device || (unique_device && *device != *unique_device)) {
+ return tensorflow::gtl::nullopt;
+ }
+ unique_device = device;
}
+ return unique_device;
}
- if (!replicated_ && maximal_ && !IsTuple()) {
+ if (!replicated_ && maximal_) {
return static_cast<int64>(*tile_assignment_.begin());
}
- return tensorflow::errors::InvalidArgument(
- "UniqueDevice() called on sharding that executes on multiple devices");
+ return tensorflow::gtl::nullopt;
}
-bool HloSharding::HasUniqueDevice() const {
- if (IsTuple()) {
- return UniqueDevice().status().ok();
- } else {
- return !IsReplicated() && IsTileMaximal();
- }
+int64 HloSharding::GetUniqueDevice() const {
+ auto device = UniqueDevice();
+ CHECK(device) << "Sharding does not have a unique device: " << *this;
+ return *device;
}
Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 6f672b0f28..28575c0e75 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -158,12 +158,17 @@ class HloSharding {
// REQUIRES: !IsTuple()
std::vector<int64> TileLimitForDevice(int64 device) const;
- // Returns the single device this op operates on.
- // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal()
- StatusOr<int64> UniqueDevice() const;
+ // Returns the single device this op operates on. If the sharding does not
+ // span a single device, the return value will be empty.
+ // In order for a sharding to span a single device, every leaf sharding must
+ // be maximal and not replicated, and the used device must match.
+ tensorflow::gtl::optional<int64> UniqueDevice() const;
+
+ // Retrieves the unique device or fails with a CHECK.
+ int64 GetUniqueDevice() const;
// Returns true if this op only uses a single device.
- bool HasUniqueDevice() const;
+ bool HasUniqueDevice() const { return UniqueDevice().has_value(); }
// Returns the ShapeTree containing the shardings for each element of this
// tuple, if IsTuple, or a ShapeTree with a single element containing this
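
The call-site updates elsewhere in this commit all follow the same migration; a condensed before/after sketch for a caller holding an HloSharding, where `sharding` and `Use` are illustrative names, not taken from the patch:

  // Before: StatusOr<int64>-based API.
  if (sharding.HasUniqueDevice()) {
    Use(sharding.UniqueDevice().ValueOrDie());
  }

  // After: optional-based API; the optional doubles as the "is single-device"
  // test.
  if (auto device = sharding.UniqueDevice()) {
    Use(*device);
  }
  // Or, when the sharding is already known to be single-device:
  Use(sharding.GetUniqueDevice());  // CHECK-fails otherwise.
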
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 7baa927d0e..aebda562d3 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -51,7 +51,7 @@ TEST_F(HloShardingTest, Replicate) {
EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}),
/*num_devices=*/2));
- EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+ EXPECT_FALSE(sharding.HasUniqueDevice());
}
TEST_F(HloShardingTest, DevicePlacement) {
@@ -60,7 +60,7 @@ TEST_F(HloShardingTest, DevicePlacement) {
EXPECT_TRUE(sharding.IsTileMaximal());
EXPECT_FALSE(sharding.UsesDevice(0));
EXPECT_TRUE(sharding.UsesDevice(5));
- EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie());
+ EXPECT_EQ(5, sharding.GetUniqueDevice());
HloSharding other = HloSharding::Replicate();
EXPECT_NE(other, sharding);
@@ -123,7 +123,7 @@ TEST_F(HloShardingTest, Tile) {
EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector<int64>{2, 0}));
EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector<int64>{2, 3}));
- EXPECT_IS_NOT_OK(sharding.UniqueDevice());
+ EXPECT_FALSE(sharding.HasUniqueDevice());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index 48f676db85..b78bfa0cdf 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction(
}
};
string node_name;
- if (debug_options_.xla_hlo_tfgraph_device_scopes() &&
- instruction->has_sharding() &&
- instruction->sharding().HasUniqueDevice()) {
- node_name = StrCat(
- "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie());
+ if (debug_options_.xla_hlo_tfgraph_device_scopes()) {
+ auto device = instruction->sharding_unique_device();
+ if (device) {
+ node_name = StrCat("dev", *device);
+ }
}
// If an instruction is fused, put it in the subgraph of the fusion;
// otherwise, put it in the computation subgraph.
@@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
NodeDef* node_def = graph_def_.add_node();
node_def->set_name(GetNodeNameForInstruction(instruction));
node_def->set_op(GetOpDefName(instruction));
- if (instruction->has_sharding() &&
- instruction->sharding().HasUniqueDevice()) {
- TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice());
- node_def->set_device(GetDeviceName(device));
+
+ auto device = instruction->sharding_unique_device();
+ if (device) {
+ node_def->set_device(GetDeviceName(*device));
}
SetNodeAttrs(instruction, node_def);
if (instruction->opcode() == HloOpcode::kFusion) {
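
Both hunks in this file switch to instruction->sharding_unique_device(), an accessor whose definition is outside this excerpt; presumably it folds the previous has_sharding()/HasUniqueDevice() checks into a single optional, along these lines:

  // Sketch of the assumed accessor on HloInstruction (not shown in this diff).
  tensorflow::gtl::optional<int64> sharding_unique_device() const {
    if (!has_sharding()) {
      return tensorflow::gtl::nullopt;
    }
    return sharding().UniqueDevice();
  }
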
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 4e3c9df3a0..7fd99fc930 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out,
string InstructionValueSet::ToString() const {
string out =
StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
- ForEachElement([this, &out](const ShapeIndex& index,
- const HloValueSet& value_set) {
+ ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
});
return out;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 25fa319faf..1a8c206aaf 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -224,10 +224,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
+ if (!ShapeUtil::IsArray(reduce->shape())) {
+ return InvalidArgument("Variadic reduce is not supported.");
+ }
return CheckShape(
reduce,
ShapeInference::InferReduceShape(
- reduce->operand(0)->shape(), reduce->operand(1)->shape(),
+ {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()},
reduce->dimensions(), reduce->to_apply()->ComputeProgramShape()));
}
@@ -510,6 +513,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
+Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
+ return CheckShape(
+ scatter, ShapeInference::InferScatterShape(
+ scatter->operand(0)->shape(), scatter->operand(1)->shape(),
+ scatter->operand(2)->shape(),
+ scatter->to_apply()->ComputeProgramShape(),
+ scatter->scatter_dimension_numbers()));
+}
+
Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : token->operands()) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 79f7aa9f4c..7feddaeabf 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -83,6 +83,7 @@ class ShapeVerifier : public DfsHloVisitor {
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleGather(HloInstruction* gather) override;
+ Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* token) override;
Status FinishVisit(HloInstruction*) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
index d7458c338e..bb5b40a8a8 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
@@ -36,7 +36,8 @@ string HumanReadableProfileBuilder::ToString() const {
computation_name_.c_str(),
HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str());
- auto print_op = [&](const OpInfo& op) {
+ int64 cumulative_cycles = 0;
+ auto print_op = [&](const OpInfo& op, bool is_total = false) {
// Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that
// were expected to be free and are actually free -- things like (on most
// backends) kParameter or kConstant HLOs. There's no need to clutter the
@@ -59,27 +60,44 @@ string HumanReadableProfileBuilder::ToString() const {
}
}
+ double cumulative_cycles_percent = 0;
double cycles_percent = 0;
+ if (!is_total) {
+ cumulative_cycles += op.cycles;
+ }
if (total_cycles_ > 0) {
cycles_percent = op.cycles / static_cast<double>(total_cycles_) * 100;
+ cumulative_cycles_percent =
+ cumulative_cycles / static_cast<double>(total_cycles_) * 100;
+ }
+
+ string cycles_percent_str;
+ if (is_total) {
+ // Leaving off the two trailing decimal points of "100.%" lets us save two
+ // columns in the output.
+ cycles_percent_str = "100.% 100Σ";
+ } else {
+ cycles_percent_str =
+ Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent);
}
double nsecs = op.cycles / clock_rate_ghz_;
- Appendf(&s,
- "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s "
- ":: %18s :: %14s :: %16s :: %s\n",
- op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles),
- op.optimal_seconds < 0
- ? ""
- : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(),
- op.flop_count <= 0
- ? ""
- : HumanReadableNumFlops(op.flop_count, nsecs).c_str(),
- op.transcendental_count <= 0 ? ""
- : HumanReadableNumTranscendentalOps(
- op.transcendental_count, nsecs)
- .c_str(),
- bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str());
+ Appendf(
+ &s,
+ "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: "
+ "%16s :: %s\n",
+ op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles),
+ op.optimal_seconds < 0
+ ? ""
+ : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(),
+ op.flop_count <= 0
+ ? ""
+ : HumanReadableNumFlops(op.flop_count, nsecs).c_str(),
+ op.transcendental_count <= 0
+ ? ""
+ : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs)
+ .c_str(),
+ bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str());
};
float optimal_seconds_sum = 0.0;
@@ -98,7 +116,8 @@ string HumanReadableProfileBuilder::ToString() const {
VLOG(1) << "Total floating point ops: " << total_flops;
print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops,
- total_transcendentals, total_bytes, optimal_seconds_sum});
+ total_transcendentals, total_bytes, optimal_seconds_sum},
+ /*is_total=*/true);
// Sort ops in decreasing order of cycles, and print them.
std::vector<OpInfo> sorted_ops(op_infos_);
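
With the cumulative column, each non-total row now shows its own share of cycles next to a running sum taken in print order (ops are printed in decreasing cycle order). For example, a row rendered as "123456 cycles ( 4.20%  37Σ) :: ..." means this op accounts for 4.20% of all cycles and the rows printed so far for 37% cumulatively, while the [total] row keeps the fixed "100.% 100Σ" marker.
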
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index af07370135..e2191aedb7 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -141,6 +141,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kReduceWindow:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
+ case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 9705687b00..b5a9d6e8e7 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -874,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
// HostCompute module.
// Otherwise it is preferable to leave the new instruction without device,
// and let the automatic device placer to choose the best location.
- if (!sharding.HasUniqueDevice() ||
- HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) {
+ auto device = sharding.UniqueDevice();
+ if (!device || HloSharding::IsReservedDevice(*device)) {
copy->set_sharding(sharding);
}
}
@@ -1228,7 +1228,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs(
const PointsToSet& points_to_set =
constraints->points_to_analysis().GetPointsToSet(instruction);
return points_to_set.ForEachElementWithStatus(
- [this, &shape_layout, constraints](
+ [&shape_layout, constraints](
const ShapeIndex& index,
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
index 941d940684..fe5ec1cc66 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -56,12 +56,12 @@ ENTRY while3 {
)";
CompileAndVerifyIr(hlo_string, R"(
-; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval
+; CHECK-LABEL: @body(i8* %retval
; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
;
-; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params
-; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0
+; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params
+; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0
; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]]
; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float*
; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]]
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ce070bc5b6..212db0643c 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -56,7 +56,6 @@ limitations under the License.
using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrCat;
-using ::xla::source_map_util::InvalidParameterArgument;
namespace xla {
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 35df792b07..c888bbf144 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -58,66 +58,101 @@ Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
return Status::OK();
}
-Status VerifyReducerShape(const ProgramShape& reducer_shape,
- const Shape& init_value_shape,
- const PrimitiveType& input_element_type) {
- if (reducer_shape.parameters_size() != 2) {
- return InvalidArgument(
- "Reduction function must take 2 parameters, but "
+Status VerifyReducerShape(
+ const ProgramShape& reducer_shape,
+ tensorflow::gtl::ArraySlice<const Shape*> init_value_shapes,
+ tensorflow::gtl::ArraySlice<PrimitiveType> input_element_types,
+ int64 inputs) {
+ if (reducer_shape.parameters_size() != inputs * 2) {
+ return InvalidArgument(
+ "Reduction function must take %lld parameters, but "
"takes %d parameter(s).",
- reducer_shape.parameters_size());
+ inputs * 2, reducer_shape.parameters_size());
}
const Shape& accumulator_shape = reducer_shape.result();
- if (!ShapeUtil::IsArray(accumulator_shape) ||
- ShapeUtil::Rank(accumulator_shape) != 0) {
- return InvalidArgument(
- "Reduction function must produce a scalar but has shape: %s",
- ShapeUtil::HumanString(accumulator_shape).c_str());
- }
-
- // Check that the accumulator can be passed in as the first argument.
- // Note: comparing here and below with Compatible since we don't care about
- // layout in scalars - see b/26668201 for a longer-term vision.
- if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) {
+ std::vector<const Shape*> accumulator_subshapes;
+ if (ShapeUtil::IsArray(accumulator_shape)) {
+ if (inputs != 1) {
+ return InvalidArgument(
+ "Reduction function must produce a tuple with %lld elements, but "
+ "produces a scalar",
+ inputs);
+ }
+ accumulator_subshapes.push_back(&accumulator_shape);
+ } else if (ShapeUtil::IsTuple(accumulator_shape)) {
+ if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) {
+ return InvalidArgument(
+ "Reduction function must produce a tuple with %lld elements, but has "
+ "%lld elements",
+ inputs, ShapeUtil::TupleElementCount(accumulator_shape));
+ }
+ for (const Shape& element_shape : accumulator_shape.tuple_shapes()) {
+ accumulator_subshapes.push_back(&element_shape);
+ }
+ } else {
return InvalidArgument(
- "Reduction function's first parameter shape differs from the "
- "result shape: %s vs %s",
- ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(),
+ "Reduction function must produce a scalar or tuple of scalars, but has "
+ "shape: %s",
ShapeUtil::HumanString(accumulator_shape).c_str());
}
- // Check that init_value's shape is suitable for reducer_shape.
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
- init_value_shape)) {
- return InvalidArgument(
- "Reduction function's accumulator shape differs from the "
- "init_value shape: %s vs %s",
- ShapeUtil::HumanString(accumulator_shape).c_str(),
- ShapeUtil::HumanString(init_value_shape).c_str());
- }
-
- // Check that the inputs can be passed in as the second argument.
- const Shape& input_element_shape =
- ShapeUtil::MakeShape(input_element_type, {});
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
- reducer_shape.parameters(1))) {
- return InvalidArgument(
- "Reduction function's second parameter shape differs from the "
- "input type element type: %s vs %s",
- ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(input_element_shape).c_str());
+ for (const Shape* element_shape : accumulator_subshapes) {
+ if (ShapeUtil::Rank(*element_shape) != 0) {
+ return InvalidArgument(
+ "Reduction function must return a scalar or tuple of scalars but "
+ "returns shape: %s",
+ ShapeUtil::HumanString(accumulator_shape).c_str());
+ }
}
- // Currently the accumulator and inputs must be the same type,
- // though that restriction could be relaxed.
- if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
- reducer_shape.parameters(1))) {
- return InvalidArgument(
- "Reduction function's second parameter shape must "
- "match the result shape, but got %s vs %s.",
- ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
- ShapeUtil::HumanString(accumulator_shape).c_str());
+ for (int64 i = 0; i < inputs; ++i) {
+ // Check that the accumulator can be passed in as the first argument.
+ // Note: comparing here and below with Compatible since we don't care about
+ // layout in scalars - see b/26668201 for a longer-term vision.
+ if (!ShapeUtil::Compatible(*accumulator_subshapes[i],
+ reducer_shape.parameters(i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape differs from the "
+ "result shape: %s vs %s",
+ i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ }
+ // Check that init_value's shapes are suitable for reducer_shape.
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i],
+ *init_value_shapes[i])) {
+ return InvalidArgument(
+ "Reduction function's accumulator shape at index %lld differs from "
+ "the init_value shape: %s vs %s",
+ i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(),
+ ShapeUtil::HumanString(*init_value_shapes[i]).c_str());
+ }
+ // Check that the inputs can be passed in as the non-accumulator arguments.
+ const Shape input_element_shape =
+ ShapeUtil::MakeShape(input_element_types[i], {});
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(
+ input_element_shape, reducer_shape.parameters(inputs + i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape differs from the "
+ "input type element type: %s vs %s",
+ inputs + i,
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
+ ShapeUtil::HumanString(input_element_shape).c_str());
+ }
+ // Check that the accumulator and inputs to the reducer function match.
+ // If the accumulator is scalar, it must have the same type as the inputs
+ // (up to fp precision). If it is a tuple, then the k-th element of the
+    // tuple must have the same type as the k-th input (again, up to fp
+    // precision).
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(
+ *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) {
+ return InvalidArgument(
+ "Reduction function's %lld-th parameter shape must "
+ "match the result shape, but got %s vs %s.",
+ inputs + i,
+ ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(),
+ ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str());
+ }
}
return Status::OK();
@@ -1745,10 +1780,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
- const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
const ProgramShape& to_apply) {
- // Check that the dimension to reduce are in-bounds for the given shape.
+ if (arg_shapes.empty()) {
+ return InvalidArgument("Reduce must have at least 2 arguments, has 0");
+ }
+ if (arg_shapes.size() % 2) {
+ return InvalidArgument(
+ "Reduce must have an even number of arguments, has %lu",
+ arg_shapes.size());
+ }
+ int64 num_reduced_args = arg_shapes.size() / 2;
+
+ tensorflow::gtl::ArraySlice<const Shape*> reduced_args(arg_shapes, 0,
+ num_reduced_args);
+ // Check that all of the reduced tensors have the same dimensions. The element
+ // types may be different.
+ for (int64 i = 1; i < num_reduced_args; ++i) {
+ if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) {
+ return InvalidArgument(
+          "All reduced tensors must have the same dimensions. Tensor 0 has "
+ "shape %s, Tensor %lld has shape %s",
+ ShapeUtil::HumanString(*reduced_args[0]).c_str(), i,
+ ShapeUtil::HumanString(*reduced_args[i]).c_str());
+ }
+ }
+
+ // Check that the dimensions to reduce are in-bounds for the given shape.
+ // We've already verified all reduced tensors have the same dimensions, so it
+ // doesn't matter which one we choose.
+ const Shape& arg = *reduced_args[0];
for (int64 dimension : dimensions_to_reduce) {
if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
return InvalidArgument(
@@ -1756,8 +1818,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(arg).c_str());
}
}
- TF_RETURN_IF_ERROR(
- VerifyReducerShape(to_apply, init_value, arg.element_type()));
+
+ tensorflow::gtl::ArraySlice<const Shape*> init_values(
+ arg_shapes, num_reduced_args, arg_shapes.size());
+ std::vector<PrimitiveType> element_types;
+ for (const Shape* arg : reduced_args) {
+ element_types.push_back(arg->element_type());
+ }
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types,
+ num_reduced_args));
std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
dimensions_to_reduce.end());
@@ -1768,15 +1837,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
- return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
+ if (ShapeUtil::IsScalar(to_apply.result())) {
+ return ShapeUtil::MakeShape(to_apply.result().element_type(),
+ new_dimensions);
+ } else {
+ std::vector<Shape> result_subshapes;
+ for (const Shape& subshape : to_apply.result().tuple_shapes()) {
+ result_subshapes.push_back(
+ ShapeUtil::MakeShape(subshape.element_type(), new_dimensions));
+ }
+ return ShapeUtil::MakeTupleShape(result_subshapes);
+ }
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window, const ProgramShape& to_apply_shape) {
TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
- TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
- operand_shape.element_type()));
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
+ {operand_shape.element_type()},
+ /*inputs=*/1));
return InferWindowOutputShape(operand_shape, window,
init_value_shape.element_type(),
/*allow_negative_padding=*/false);
@@ -1821,8 +1901,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
// Check if the scatter function has a proper shape as a reduction.
- TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
- source_shape.element_type()));
+ TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape},
+ {source_shape.element_type()},
+ /*inputs=*/1));
// Check if the result shape of window operation matches the source shape.
TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
@@ -2568,4 +2649,194 @@ static Status ValidateGatherDimensionNumbers(
return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
}
+namespace {
+
+Status ValidateScatterDimensionNumbers(
+ const Shape& operand_shape,
+ tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
+ const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
+ // Validate update_window_dims in ScatterDimensionNumbers.
+ if (!c_is_sorted(dim_numbers.update_window_dims())) {
+ return InvalidArgument(
+ "update_window_dims in scatter op must be sorted; got: %s.",
+ Join(dim_numbers.update_window_dims(), ", ").c_str());
+ }
+ if (c_adjacent_find(dim_numbers.update_window_dims()) !=
+ dim_numbers.update_window_dims().end()) {
+ return InvalidArgument(
+ "update_window_dims in scatter op must not repeat; got: %s.",
+ Join(dim_numbers.update_window_dims(), ", ").c_str());
+ }
+ const int64 updates_rank = ShapeUtil::Rank(updates_shape);
+ for (int64 window_dim : dim_numbers.update_window_dims()) {
+ if (window_dim < 0 || window_dim >= updates_rank) {
+ return InvalidArgument(
+ "Invalid update_window_dims set in scatter op; valid range is [0, "
+ "%lld). got: %lld.",
+ updates_rank, window_dim);
+ }
+ }
+
+ // Validate inserted_window_dims in ScatterDimensionNumbers.
+ if (!c_is_sorted(dim_numbers.inserted_window_dims())) {
+ return InvalidArgument(
+ "inserted_window_dims in scatter op must be sorted; got: %s.",
+ Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ }
+ if (c_adjacent_find(dim_numbers.inserted_window_dims()) !=
+ dim_numbers.inserted_window_dims().end()) {
+ return InvalidArgument(
+ "inserted_window_dims in scatter op must not repeat; got: %s.",
+ Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ }
+ for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
+ if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid inserted_window_dims set in scatter op; valid range is [0, "
+ "%d), got: %lld.",
+ operand_shape.dimensions_size(), inserted_dim);
+ }
+ }
+
+ // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
+ if (dim_numbers.scatter_dims_to_operand_dims_size() !=
+ scatter_indices_shape[dim_numbers.index_vector_dim()]) {
+ return InvalidArgument(
+ "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
+ "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. "
+ "These two numbers must be equal.",
+ dim_numbers.scatter_dims_to_operand_dims_size(),
+ dim_numbers.index_vector_dim(),
+ scatter_indices_shape[dim_numbers.index_vector_dim()]);
+ }
+ for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
+ int64 scatter_dim_to_operand_dim =
+ dim_numbers.scatter_dims_to_operand_dims(i);
+ if (scatter_dim_to_operand_dim < 0 ||
+ scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
+ "got: %d->%lld.",
+ operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
+ }
+ }
+ std::vector<int64> sorted_scatter_dims_to_operand_dims(
+ dim_numbers.scatter_dims_to_operand_dims().begin(),
+ dim_numbers.scatter_dims_to_operand_dims().end());
+ c_sort(sorted_scatter_dims_to_operand_dims);
+ if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
+ sorted_scatter_dims_to_operand_dims.end()) {
+ return InvalidArgument(
+ "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
+ "got: %s.",
+ Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str());
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
+ const Shape& operand_shape, const Shape& scatter_indices_shape,
+ const Shape& updates_shape, const ProgramShape& to_apply_shape,
+ const ScatterDimensionNumbers& scatter_dim_numbers) {
+ TF_RETURN_IF_ERROR(
+ ExpectArray(operand_shape, "operand tensor of scatter op"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
+ TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
+
+ if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
+ return InvalidArgument(
+ "Scatter indices parameter must be an integral tensor; got %s.",
+ ShapeUtil::HumanString(scatter_indices_shape).c_str());
+ }
+
+ if (scatter_indices_shape.dimensions_size() <
+ scatter_dim_numbers.index_vector_dim() ||
+ scatter_dim_numbers.index_vector_dim() < 0) {
+ return InvalidArgument(
+ "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
+ " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
+ "is %lld.",
+ scatter_indices_shape.dimensions_size(),
+ scatter_dim_numbers.index_vector_dim());
+ }
+
+ // Check if the update computation has a proper shape as a reduction.
+ const Shape init_value_shape =
+ ShapeUtil::MakeShape(operand_shape.element_type(), {});
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape},
+ {updates_shape.element_type()},
+ /*inputs=*/1));
+
+ std::vector<int64> expanded_scatter_indices_shape =
+ ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions()));
+ if (expanded_scatter_indices_shape.size() ==
+ scatter_dim_numbers.index_vector_dim()) {
+ expanded_scatter_indices_shape.push_back(1);
+ }
+
+ int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
+ scatter_dim_numbers.update_window_dims_size();
+ if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) {
+ return InvalidArgument("Updates tensor must be of rank %lld; got %lld.",
+ expected_updates_rank,
+ ShapeUtil::Rank(updates_shape));
+ }
+
+ TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
+ operand_shape, expanded_scatter_indices_shape, updates_shape,
+ scatter_dim_numbers));
+
+ int64 inserted_dims_seen = 0;
+ std::vector<int64> max_update_window_bounds;
+ for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
+ if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
+ scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
+ ++inserted_dims_seen;
+ } else {
+ max_update_window_bounds.push_back(operand_shape.dimensions(i));
+ }
+ }
+ for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
+ auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
+ if (updates_shape.dimensions(update_window_dim) >
+ max_update_window_bounds[i]) {
+ return InvalidArgument(
+ "Bounds of the window dimensions of updates must not exceed the "
+ "bounds of the corresponding dimensions of operand. For dimension "
+ "%lld, updates bound is %lld, operand bound is %lld.",
+ update_window_dim, updates_shape.dimensions(update_window_dim),
+ max_update_window_bounds[i]);
+ }
+ }
+
+ int64 scatter_dims_seen = 0;
+ for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) {
+ bool is_update_window_dim =
+ c_binary_search(scatter_dim_numbers.update_window_dims(), i);
+ if (is_update_window_dim) {
+ continue;
+ }
+ if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
+ ++scatter_dims_seen;
+ }
+ if (updates_shape.dimensions(i) !=
+ expanded_scatter_indices_shape[scatter_dims_seen]) {
+ return InvalidArgument(
+          "Bounds of the scatter dimensions of updates must be the same as the "
+ "bounds of the corresponding dimensions of scatter indices. For "
+ "scatter dimension %lld, updates bound is %lld, scatter_indices "
+ "bound is %lld.",
+ i, updates_shape.dimensions(i),
+ expanded_scatter_indices_shape[scatter_dims_seen]);
+ }
+ ++scatter_dims_seen;
+ }
+
+ return operand_shape;
+}
+
} // namespace xla
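
As a worked example, the scatter from the parser test earlier in this change satisfies the new checks: with operand f32[50,49,48,47,46], scatter_indices s64[10,9,8,7,5], updates f32[10,9,8,7,30,29,28,27,26], update_window_dims={4,5,6,7,8} and index_vector_dim=4,

  expected_updates_rank = (rank(scatter_indices) - 1) + |update_window_dims|
                        = (5 - 1) + 5
                        = 9 = rank(updates)

The scatter dimensions of updates (bounds {10,9,8,7}) match the scatter_indices bounds with the index vector dimension removed, and the update window bounds {30,29,28,27,26} do not exceed the operand bounds, so InferScatterShape returns the operand shape f32[50,49,48,47,46].
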
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 1a5684e3c3..33da323b3d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -131,7 +131,7 @@ class ShapeInference {
// index as the leading parameter, and the program shape should match
// accordingly (or an error will result).
static StatusOr<Shape> InferReduceShape(
- const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
const ProgramShape& to_apply);
@@ -268,6 +268,14 @@ class ShapeInference {
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ // Helper that validates the given input shape, scatter indices shape, updates
+ // shape, and scatter dimension numbers that constitute a scatter operation,
+ // and returns the result shape of the scatter operation.
+ static StatusOr<Shape> InferScatterShape(
+ const Shape& operand_shape, const Shape& scatter_indices_shape,
+ const Shape& updates_shape, const ProgramShape& to_apply_shape,
+ const ScatterDimensionNumbers& scatter_dim_numbers);
+
private:
// Helper that infers the shape produced by performing an element-wise binary
// operation with the given LHS and RHS shapes.
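
The variadic reduce form packs the reduced operands first, followed by their matching init values, in arg_shapes. A hedged call sketch for a two-output (argmax-style) reduce, mirroring the TupleReduce example above but giving the index operand the s32 element type its reducer expects:

  // arg_shapes = {operand_0, ..., operand_{n-1}, init_0, ..., init_{n-1}}.
  Shape values = ShapeUtil::MakeShape(F32, {1024});
  Shape indices = ShapeUtil::MakeShape(S32, {1024});
  Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
  ProgramShape max_argmax = ShapeUtil::MakeProgramShape(
      {f32_scalar, s32_scalar, f32_scalar, s32_scalar},
      ShapeUtil::MakeTupleShape({f32_scalar, s32_scalar}));
  TF_ASSIGN_OR_RETURN(
      Shape reduce_shape,
      ShapeInference::InferReduceShape(
          {&values, &indices, &f32_scalar, &s32_scalar},
          /*dimensions_to_reduce=*/{0}, max_argmax));
  // reduce_shape is the tuple (f32[], s32[]).
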
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 6046d50c6d..a73fa181cd 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -63,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest {
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
auto inferred_status = ShapeInference::InferReduceShape(
- arg, f32_, dimensions_to_reduce, to_apply);
+ {&arg, &f32_}, dimensions_to_reduce, to_apply);
EXPECT_IS_OK(inferred_status.status());
EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
inferred_status.ValueOrDie()));
@@ -703,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
/*dimensions_to_reduce=*/{0, 1, 2});
}
+TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_IS_OK(inferred_status.status());
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply =
+ ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
+ ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr(
+ "parameter shape differs from the result shape: s32[] vs f32[]"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must have at least 2 arguments, has 0"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply =
+ ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(
+ inferred_status.status().error_message(),
+ HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
+ Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
+ Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
+ auto inferred_status = ShapeInference::InferReduceShape(
+ {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_THAT(inferred_status.status().error_message(),
+ HasSubstr("accumulator shape at index 0 differs from the "
+ "init_value shape: s32[] vs f32[]"));
+}
+
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status = ShapeInference::InferReduceShape(
- ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4},
- to_apply);
+ {&arg_shape, &f32_},
+ /*dimensions_to_reduce=*/{3, 4}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
HasSubstr("out-of-bounds dimension"));
@@ -715,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status =
- ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ ShapeInference::InferReduceShape({&arg_shape, &f32_},
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
@@ -725,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
+ Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
auto inferred_status =
- ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ ShapeInference::InferReduceShape({&arg_shape, &f32_},
/*dimensions_to_reduce=*/{0}, to_apply);
EXPECT_FALSE(inferred_status.ok());
EXPECT_THAT(inferred_status.status().error_message(),
- HasSubstr("first parameter shape differs"));
+ HasSubstr("0-th parameter shape differs"));
}
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
@@ -1536,7 +1626,7 @@ TEST_F(ShapeInferenceTest, BadSort) {
<< statusor.status();
}
-class GatherShapeInferenceTest : public ShapeInferenceTest {
+class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
protected:
const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
@@ -1553,9 +1643,13 @@ class GatherShapeInferenceTest : public ShapeInferenceTest {
ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
{s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
+ const ProgramShape to_apply_ =
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
};
-TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
+// Shape inference tests for Gather.
+
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
@@ -1570,7 +1664,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
@@ -1585,7 +1679,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
@@ -1600,7 +1694,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
ShapeInference::InferGatherShape(
@@ -1617,7 +1711,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
ShapeInference::InferGatherShape(
@@ -1635,7 +1729,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
ShapeInference::InferGatherShape(
@@ -1653,7 +1747,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
+TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
// This is equivalent to a dynamic slice.
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
@@ -1671,7 +1765,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
+TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
// The gather indices "tensor" is a scalar S here that's used to slice out
// [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
@@ -1689,7 +1783,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
+TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1704,7 +1798,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
+TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1719,7 +1813,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
+TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1734,7 +1828,7 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1751,7 +1845,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1768,7 +1862,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1784,7 +1878,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1800,7 +1894,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1818,7 +1912,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1835,7 +1929,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1853,7 +1947,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1872,7 +1966,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1890,7 +1984,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1908,7 +2002,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1924,7 +2018,8 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowBoundsTooLarge) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1940,7 +2035,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1958,7 +2053,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1975,7 +2070,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
+TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1992,5 +2087,575 @@ TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
<< statusor.status();
}
+// Shape inference tests for Scatter.
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {64, 32}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {32, 48}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {10, 32}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {32, 8}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Bounds of the window dimensions of updates must not exceed "
+ "the bounds of the corresponding dimensions of operand."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Bounds of the window dimensions of updates must not exceed "
+ "the bounds of the corresponding dimensions of operand."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ TfScatterWithUpdatesNotMatchingIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Bounds of the scatter dimensions of updates must be same as the "
+ "bounds of the corresponding dimensions of scatter indices."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ TfScatterWithUpdatesNotMatchingIndicesV2) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Bounds of the scatter dimensions of updates must be same as the "
+ "bounds of the corresponding dimensions of scatter indices."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Bounds of the window dimensions of updates must not exceed "
+ "the bounds of the corresponding dimensions of operand."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ TfScatterNdWithUpdatesNotMatchingIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Bounds of the scatter dimensions of updates must be same as the "
+ "bounds of the corresponding dimensions of scatter indices."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
+ // This is equivalent to a dynamic update slice.
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+ ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0, 1, 2, 3, 4},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
+ // The scalar indices "tensor" is a scalar S here that's used to update a
+ // [30,29,28,27] shaped tensor within the operand at position S.
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
+ ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0, 1, 2, 3},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/0)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected array argument for operand"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ ScatterWithTupleShapedScatterIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected array argument for scatter indices"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected array argument for updates"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ s64_vector_32_, vector_32_, s64_vector_32_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Scatter indices parameter must be an integral tensor"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/10));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Scatter index leaf dimension must be within [0, "
+ "rank(scatter_indices) + 1)"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Updates tensor must be of rank 7; got 8."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
+ const ProgramShape invalid_update_computation =
+ ShapeUtil::MakeProgramShape({f32_}, f32_);
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}),
+ invalid_update_computation,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Reduction function must take 2 parameters, but takes 1"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 8, 7},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("update_window_dims in scatter op must be sorted"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 7},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("update_window_dims in scatter op must not repeat"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 9},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid update_window_dims set in scatter op; valid "
+ "range is [0, 9)"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{2, 1},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("inserted_window_dims in scatter op must be sorted"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 1},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("inserted_window_dims in scatter op must not repeat"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 5},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
+ "range is [0, 5)"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
+ "the bound of dimension index_vector_dim=4 of scatter_indices "
+ "is 5. These two numbers must be equal"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
+ "is [0, 5), got: 4->10"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
+ << statusor.status();
+}
+
} // namespace
} // namespace xla
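
The InvalidScatterDimNumbers_* tests added above all exercise the same dimension-number checks. A minimal standalone sketch of those checks follows, using only the standard library; the struct and function names are illustrative stand-ins, not the real XLA shape-inference API.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct ScatterDimNums {
  std::vector<int64_t> update_window_dims;
  std::vector<int64_t> inserted_window_dims;
  std::vector<int64_t> scatter_dims_to_operand_dims;
  int64_t index_vector_dim;
};

// Returns an error message, or "" when the numbers pass the checks mirrored
// from the InvalidScatterDimNumbers_* tests above.
std::string ValidateScatterDimNumbers(const ScatterDimNums& d,
                                      int64_t operand_rank, int64_t updates_rank,
                                      const std::vector<int64_t>& indices_dims) {
  auto sorted_and_unique = [](const std::vector<int64_t>& v) {
    return std::is_sorted(v.begin(), v.end()) &&
           std::adjacent_find(v.begin(), v.end()) == v.end();
  };
  if (!sorted_and_unique(d.update_window_dims))
    return "update_window_dims must be sorted and must not repeat";
  if (!sorted_and_unique(d.inserted_window_dims))
    return "inserted_window_dims must be sorted and must not repeat";
  for (int64_t dim : d.update_window_dims)
    if (dim < 0 || dim >= updates_rank) return "update_window_dims out of range";
  for (int64_t dim : d.inserted_window_dims)
    if (dim < 0 || dim >= operand_rank) return "inserted_window_dims out of range";
  // One scatter_dims_to_operand_dims entry is required per element of the
  // index vector, i.e. per element along dimension index_vector_dim of the
  // scatter indices (index_vector_dim equal to the rank means the indices are
  // scalars, so the bound is 1).
  const int64_t index_vector_bound =
      d.index_vector_dim < static_cast<int64_t>(indices_dims.size())
          ? indices_dims[d.index_vector_dim]
          : 1;
  if (static_cast<int64_t>(d.scatter_dims_to_operand_dims.size()) !=
      index_vector_bound)
    return "scatter_dims_to_operand_dims must have one entry per index element";
  for (int64_t dim : d.scatter_dims_to_operand_dims)
    if (dim < 0 || dim >= operand_rank)
      return "scatter_dims_to_operand_dims out of range";
  return "";
}

int main() {
  // Shapes as in TfBatchDynamicUpdateSlice: operand rank 5, scatter indices
  // s64[10,9,8,7,5], updates rank 9.
  ScatterDimNums ok{{4, 5, 6, 7, 8}, {}, {0, 1, 2, 3, 4}, 4};
  std::cout << "ok:  '" << ValidateScatterDimNumbers(ok, 5, 9, {10, 9, 8, 7, 5})
            << "'\n";
  ScatterDimNums bad = ok;
  bad.update_window_dims = {4, 5, 6, 8, 7};  // deliberately unsorted
  std::cout << "bad: '" << ValidateScatterDimNumbers(bad, 5, 9, {10, 9, 8, 7, 5})
            << "'\n";
  return 0;
}

Run as-is, the first call reports no error and the second reports the sortedness violation, matching the NonAscendingUpdateWindowDims case tested above.
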
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 92bb21b816..c0582c6a2d 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -27,6 +28,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
// Re-use an existing stream from the pool.
stream = std::move(streams_.back());
streams_.pop_back();
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool reusing existing stream";
}
}
@@ -34,6 +37,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
// Create a new stream.
stream = MakeUnique<se::Stream>(executor);
stream->Init();
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool created new stream";
}
// Return the stream wrapped in Ptr, which has our special deleter semantics.
@@ -43,12 +48,16 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
void StreamPool::ReturnStream(se::Stream* stream) {
if (stream->ok()) {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool returning ok stream";
tensorflow::mutex_lock lock(mu_);
streams_.emplace_back(stream);
} else {
- // If the stream has encountered any errors, all subsequent
- // operations on it will fail. So just delete the stream, and rely
- // on new streams to be created in the future.
+ // If the stream has encountered any errors, all subsequent operations on it
+ // will fail. So just delete the stream, and rely on new streams to be
+ // created in the future.
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool deleting !ok stream";
delete stream;
}
}
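
The VLOG lines added here are easiest to read against the pool's borrow/return contract: BorrowStream hands back a Ptr whose deleter returns the stream to the pool, and ReturnStream either keeps an ok stream for re-use or deletes an errored one. A toy, self-contained sketch of that pattern follows; it is plain C++ with made-up names, not the XLA StreamPool or StreamExecutor classes.

#include <iostream>
#include <memory>
#include <mutex>
#include <vector>

struct ToyStream {
  int id = 0;
  bool ok = true;
};

class ToyStreamPool {
 public:
  // The deleter returns the stream to the pool instead of destroying it,
  // which is the "special deleter semantics" mentioned in the code above.
  struct Returner {
    ToyStreamPool* pool;
    void operator()(ToyStream* s) const { pool->ReturnStream(s); }
  };
  using Ptr = std::unique_ptr<ToyStream, Returner>;

  Ptr BorrowStream() {
    std::lock_guard<std::mutex> lock(mu_);
    ToyStream* stream = nullptr;
    if (!streams_.empty()) {
      stream = streams_.back().release();  // re-use an existing stream
      streams_.pop_back();
    } else {
      stream = new ToyStream;              // create a new stream
      stream->id = next_id_++;
    }
    return Ptr(stream, Returner{this});
  }

  void ReturnStream(ToyStream* stream) {
    if (stream->ok) {
      std::lock_guard<std::mutex> lock(mu_);
      streams_.emplace_back(stream);  // keep ok streams for later re-use
    } else {
      delete stream;  // drop errored streams; a fresh one is made next time
    }
  }

 private:
  std::mutex mu_;
  std::vector<std::unique_ptr<ToyStream>> streams_;
  int next_id_ = 0;
};

int main() {
  ToyStreamPool pool;
  {
    ToyStreamPool::Ptr s = pool.BorrowStream();
    std::cout << "borrowed stream " << s->id << "\n";
  }  // s goes out of scope here and is returned to the pool by the deleter
  ToyStreamPool::Ptr s = pool.BorrowStream();
  std::cout << "borrowed stream " << s->id << " again (re-used)\n";
  return 0;
}
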
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 0effdc80a4..0447807a41 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -232,8 +232,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement(
// Copy the points-to set (and tuple sources) at index {element_index} of the
// operand to the points-to set for this GetTupleElement instruction.
points_to_set.ForEachMutableElement(
- [&, this](const ShapeIndex& target_index,
- PointsToSet::BufferList* points_to) {
+ [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
// Construct an index into the operand by prepending element_index to
// the index for the GetTupleElement instruction's points-to set.
ShapeIndex src_index;
@@ -308,7 +307,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
// Recursively copy the points to set of the operand tuple {0} to the output
// element {0}.
points_to_set.ForEachMutableElement(
- [this, &points_to_set, &operand_points_to_set](
+ [&points_to_set, &operand_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
if (index.empty() || index[0] != 0) {
return;
@@ -517,7 +516,7 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
const HloInstruction* instruction,
TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
GetPointsToSet(instruction)
- .ForEachElement([this, buffers, instruction](
+ .ForEachElement([buffers, instruction](
const ShapeIndex& index,
const PointsToSet::BufferList& source_buffers) {
// Add buffers which 'instruction' is the source of.
@@ -547,7 +546,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
const PointsToSet& src_points_to_set = GetPointsToSet(src);
dst_points_to_set.ForEachMutableElement(
- [this, &dst_points_to_set, &src_points_to_set](
+ [&dst_points_to_set, &src_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
*buffers = src_points_to_set.element(index);
for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 2e5f646804..10d382e8ab 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1118,7 +1118,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
- auto make_cond = [this, &data_shape]() {
+ auto make_cond = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -1127,7 +1127,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
- auto make_body = [this, &data_shape]() {
+ auto make_body = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Body");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
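
This hunk, like the tuple_points_to_analysis.cc, shape_tree_test.cc, and prng_test.cc hunks nearby, only trims lambda captures that the body never uses. A minimal illustration of the same cleanup, assuming nothing beyond standard C++:

#include <iostream>
#include <vector>

struct Analysis {
  std::vector<int> data{1, 2, 3};

  int Sum() const {
    int total = 0;
    // Before the cleanup the capture list here would have been [this, &total];
    // the body only touches `total`, so `this` can simply be dropped (clang
    // flags such captures under -Wunused-lambda-capture).
    auto add = [&total](int v) { total += v; };
    for (int v : data) add(v);
    return total;
  }
};

int main() {
  Analysis a;
  std::cout << a.Sum() << "\n";  // prints 6
  return 0;
}
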
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index 4391078b64..c4c958be4a 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -172,7 +172,7 @@ TEST_F(ShapeTreeTest, TupleShape) {
// Write zero to all data elements.
shape_tree.ForEachMutableElement(
- [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; });
+ [](const ShapeIndex& /*index*/, int* data) { *data = 0; });
EXPECT_EQ(0, shape_tree.element({}));
EXPECT_EQ(0, shape_tree.element({0}));
EXPECT_EQ(0, shape_tree.element({1}));
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index ec901af1e2..34869cc507 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -596,8 +596,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
};
auto comma_list_to_int64s =
- [&s,
- string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
+ [string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
for (const string& piece : tensorflow::str_util::Split(input, ',')) {
TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
@@ -792,7 +791,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
if (LayoutUtil::IsSparseArray(shape)) {
allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout());
} else {
- CHECK(LayoutUtil::IsDenseArray(shape));
+ CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString();
tensorflow::gtl::ArraySlice<int64> padded_dimensions =
LayoutUtil::PaddedDimensions(shape);
if (!padded_dimensions.empty()) {
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index d372d1ca43..24b17b7100 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -733,7 +733,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
var4D, [epsilon](float a) { return a + epsilon; });
auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
- var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); });
+ var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });
auto grad_output_times_var =
*ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index cfd36abf47..0e9e92ed99 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -111,7 +111,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR1<T>(&builder, {static_cast<T>(2.0f)});
@@ -137,7 +137,7 @@ std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
return {row_major ? 1 : 0, row_major ? 0 : 1};
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
@@ -148,7 +148,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) {
this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
@@ -160,7 +160,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) {
this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(
@@ -172,7 +172,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) {
this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
@@ -183,7 +183,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) {
&builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto param0 =
@@ -533,7 +533,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
&builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
@@ -612,7 +612,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
{x_data.get(), y_data.get()}, this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
@@ -648,7 +648,49 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
{x_data.get(), y_data.get()}, this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
+ using T = TypeParam;
+
+ XlaBuilder builder(this->TestName());
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+ "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+ "y");
+
+ DotDimensionNumbers dnums;
+ dnums.add_lhs_contracting_dimensions(3);
+ dnums.add_rhs_contracting_dimensions(2);
+ dnums.add_lhs_batch_dimensions(0);
+ dnums.add_lhs_batch_dimensions(1);
+ dnums.add_rhs_batch_dimensions(0);
+ dnums.add_rhs_batch_dimensions(1);
+
+ DotGeneral(x, y, dnums);
+
+ auto x_data =
+ this->client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ {{{9.0f, 10.0f}, {11.0f, 12.0f}},
+ {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
+ .ConsumeValueOrDie();
+
+ auto y_data =
+ this->client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
+ {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
+ .ConsumeValueOrDie();
+
+ this->template ComputeAndCompareR4<T>(
+ &builder,
+ /*expected=*/
+ {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}},
+ {x_data.get(), y_data.get()}, this->error_spec_);
+}
+
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
using T = TypeParam;
for (bool transpose_lhs : {false, true}) {
for (bool transpose_rhs : {false, true}) {
@@ -708,7 +750,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
}
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64,
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
DotOfConcatOptimizationWithConstLHS) {
using T = TypeParam;
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
@@ -754,7 +796,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64,
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
DotOfConcatOptimizationWithConstRHS) {
using T = TypeParam;
std::unique_ptr<Array2D<T>> constant_rhs_array(
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
index 47cab79604..115448c908 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -42,13 +42,12 @@ extern "C" void SumStructElements(float* out, void** parameters) {
TEST_F(LocalClientAotTest, Constant) {
xla::ExecutableRunOptions run_options;
OpaqueData opaque_data{100, 20, 3};
- void* parameters[] = {&opaque_data};
float out = 0;
- void* temporary_buffers[] = {nullptr, &out};
- SumAndDouble(&out, &run_options, parameters, temporary_buffers);
+ void* temporary_buffers[] = {&opaque_data, &out};
+ SumAndDouble(&out, &run_options, nullptr, temporary_buffers);
EXPECT_EQ(out, 246.0f);
opaque_data = {1, 2, 3};
- SumAndDouble(&out, &run_options, parameters, temporary_buffers);
+ SumAndDouble(&out, &run_options, nullptr, temporary_buffers);
EXPECT_EQ(out, 12.0f);
}
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 74494e60e8..e310966d8b 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -93,7 +93,7 @@ int main(int argc, char** argv) {
// local_client_aot_test.cc to be able to easily invoke the function.
CHECK_EQ(result->result_buffer_index(), 1);
CHECK_EQ(result->buffer_sizes().size(), 3);
- CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer
+ CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer
CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer
if (triple.isOSBinFormatELF()) {
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 029af69573..326e13b386 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -182,7 +182,7 @@ XLA_TEST_F(PrngTest, Uniformity256) {
XLA_TEST_F(PrngTest, MapUsingRng) {
// Build a x -> (x + U[0,1)) computation.
- auto build_sum_rng = [this](XlaBuilder& builder) {
+ auto build_sum_rng = [](XlaBuilder& builder) {
auto b = builder.CreateSubBuilder("sum_with_rng");
auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input");
Add(x,
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index c81c27891c..1bdf1867b9 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -1236,6 +1236,35 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
{param_value.get()}, ErrorSpec(4e-5));
}
+TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
+ auto while_shape = ShapeUtil::MakeShape(S32, {});
+
+ XlaComputation condition;
+ {
+ XlaBuilder builder("condition");
+ Parameter(&builder, 0, while_shape, "state");
+ Infeed(&builder, ShapeUtil::MakeShape(PRED, {}));
+ TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
+ }
+
+ XlaComputation body;
+ {
+ XlaBuilder builder("body");
+ auto indvar = Parameter(&builder, 0, while_shape, "state");
+ Add(indvar, ConstantR0<int32>(&builder, 1));
+ TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
+ }
+
+ XlaBuilder builder(TestName());
+ While(condition, body, ConstantR0<int32>(&builder, 0));
+
+ TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+
+ ComputeAndCompareR0<int32>(&builder, 2, {});
+}
+
void BM_WhileLoop(int num_iters) {
// Benchmark a simple kernel to measure while loop overheads.
tensorflow::testing::StopTiming();
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 0ee8e68c88..11f3efb1f3 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -84,8 +84,8 @@ Status ParseOneProfileOutputLine(
tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore =
{}) {
string separator = "[^:]*:: +";
- string match_percentage = "\\d+\\.\\d\\d%";
- string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
+ string match_percentage = R"(\d+\.\d*% +\d+Σ)";
+ string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))";
string match_usecs = "([0-9.]+) usec";
string match_flops = "([^ ]*)";
string match_trops = "([^ ]*)";
@@ -225,7 +225,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
MaybeFind(parsed_profile_lines, "tanh"));
EXPECT_GT(total_profile.cycles, 0);
- EXPECT_EQ(total_profile.cycles_percentage, "100.00%");
+ EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ");
EXPECT_TRUE(HasFlops(total_profile));
EXPECT_TRUE(HasTrops(total_profile));
@@ -333,7 +333,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
EXPECT_GT(total_while_body_profile.cycles, 0);
EXPECT_EQ(total_while_body_profile.opcode, "[total]");
- EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%");
+ EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ");
EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index d7cabbe876..40d28a57bf 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -87,6 +87,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:testing",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 3bb2f3c000..563e2d8fdb 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -182,11 +183,18 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
// Run the computation num_runs times, and return the result from the last
// execution.
+ const bool xla_hlo_profile =
+ legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
StreamExecutorMemoryAllocator allocator(
client->platform(),
{client->platform()->ExecutorForDevice(0).ValueOrDie()});
tensorflow::gtl::optional<ScopedShapedBuffer> result;
for (int i = 0; i < opts.num_runs; ++i) {
+ // If xla_hlo_profile is enabled, print a noisy message before the last run,
+ // making it easier to separate this profile from the others in the logspam.
+ if (xla_hlo_profile && i == opts.num_runs - 1) {
+ LOG(INFO) << "\n\n***** Final run below ******";
+ }
ExecutionProfile profile;
ExecutableRunOptions run_options;
run_options.set_execution_profile(&profile);
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 0b300dc7b2..fd784e909c 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -447,6 +447,20 @@ message GatherDimensionNumbers {
int64 index_vector_dim = 4;
}
+// Describes the dimension numbers for a scatter operation.
+//
+// All the fields are similar to the corresponding fields in
+// GatherDimensionNumbers. Differences are noted below.
+message ScatterDimensionNumbers {
+ // The set of dimensions in the updates shape that are window dimensions.
+ repeated int64 update_window_dims = 1;
+ // The set of window dimensions that must be inserted into the updates shape.
+ repeated int64 inserted_window_dims = 2;
+
+ repeated int64 scatter_dims_to_operand_dims = 3;
+ int64 index_vector_dim = 4;
+}
+
message ConvolutionDimensionNumbers {
// The number of the dimension that represents batch in the input.
int64 input_batch_dimension = 7;
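
For a concrete reading of the new ScatterDimensionNumbers message, the TfScatterWithFullUpdates test earlier in this diff (operand f32[64,48], scatter_indices s64[32], updates f32[64,32]) fills the fields as sketched below. The sketch uses a plain C++ struct standing in for the generated proto class, so the type name is illustrative only.

#include <cstdint>
#include <iostream>
#include <vector>

struct ScatterDimensionNumbersSketch {
  std::vector<int64_t> update_window_dims;
  std::vector<int64_t> inserted_window_dims;
  std::vector<int64_t> scatter_dims_to_operand_dims;
  int64_t index_vector_dim = 0;
};

int main() {
  ScatterDimensionNumbersSketch dnums;
  // updates dim 0 (size 64) is the window dimension; it covers operand dim 0.
  dnums.update_window_dims = {0};
  // operand dim 1 gets no window dimension in updates; each scattered window
  // is one column wide, so dim 1 is "inserted".
  dnums.inserted_window_dims = {1};
  // every index selects a position along operand dim 1 (a column).
  dnums.scatter_dims_to_operand_dims = {1};
  // index_vector_dim equals rank(scatter_indices), so each index is a scalar
  // rather than a vector of coordinates.
  dnums.index_vector_dim = 1;

  // With these numbers, updates dim 1 (size 32) lines up with the 32 indices,
  // and each 64-element slice updates[:, i] is written into operand column
  // scatter_indices[i]; the inferred output shape stays f32[64,48].
  std::cout << "window dims: " << dnums.update_window_dims.size() << "\n";
  return 0;
}
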
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 6a4e252b44..cc34db995e 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -107,7 +107,6 @@ py_library(
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/timeseries",
"//tensorflow/contrib/tpu",
- "//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py
index a36b3d77a9..2d1bed3367 100644
--- a/tensorflow/contrib/autograph/converters/call_trees.py
+++ b/tensorflow/contrib/autograph/converters/call_trees.py
@@ -238,7 +238,7 @@ class CallTreeTransformer(converter.Base):
# Before we could convert all the time though, we'd need a reasonable
# caching mechanism.
template = """
- ag__.converted_call(func, True, False, {}, args)
+ ag__.converted_call(func, True, False, False, {}, args)
"""
call_expr = templates.replace(template, func=node.func, args=node.args)
new_call = call_expr[0].value
diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/contrib/autograph/converters/directives.py
index ccdf79d47b..77f625bac7 100644
--- a/tensorflow/contrib/autograph/converters/directives.py
+++ b/tensorflow/contrib/autograph/converters/directives.py
@@ -42,10 +42,30 @@ def _map_args(call_node, function):
Returns:
Dict[Text, ast.AST], mapping each of the function's argument names to
the respective AST node.
+ Raises:
+ ValueError: if the default arguments are not correctly set
"""
args = call_node.args
kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
- return tf_inspect.getcallargs(function, *args, **kwds)
+ call_args = tf_inspect.getcallargs(function, *args, **kwds)
+
+ # Keyword arguments not specified in kwds will be mapped to their defaults,
+ # which are Python values. Since we don't currently have a way to transform
+ # those into AST references, we simply remove them. By convention, directives
+  # use UNSPECIFIED as the default value for optional arguments. No other
+ # defaults should be present.
+ unexpected_defaults = []
+ for k in call_args:
+ if (k not in kwds
+ and call_args[k] not in args
+ and call_args[k] is not directives.UNSPECIFIED):
+ unexpected_defaults.append(k)
+ if unexpected_defaults:
+ raise ValueError('Unexpected keyword argument values, %s, for function %s'
+ % (zip(unexpected_defaults,
+ [call_args[k] for k in unexpected_defaults]),
+ function))
+ return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}
class DirectivesTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py
index a573ba5850..a2d083b891 100644
--- a/tensorflow/contrib/autograph/converters/directives_test.py
+++ b/tensorflow/contrib/autograph/converters/directives_test.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.contrib.autograph.core.converter import AgAnno
from tensorflow.contrib.autograph.lang import directives
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.python.platform import test
@@ -71,7 +72,23 @@ class DirectivesTest(converter_testing.TestCase):
d = d[directives.set_loop_options]
self.assertEqual(d['parallel_iterations'].n, 10)
self.assertEqual(d['back_prop'].id, 'a')
- self.assertEqual(d['swap_memory'], directives.UNSPECIFIED)
+ self.assertNotIn('swap_memory', d)
+
+ def test_invalid_default(self):
+
+ def invalid_directive(valid_arg, invalid_default=object()):
+ del valid_arg
+ del invalid_default
+ return
+
+ def call_invalid_directive():
+ invalid_directive(1)
+
+ node, _ = parser.parse_entity(call_invalid_directive)
+ # Find the call to the invalid directive
+ node = node.body[0].body[0].value
+ with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
+ directives_converter._map_args(node, invalid_directive)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py
index cd74e5f18f..5d61b220af 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers_test.py
+++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py
@@ -34,8 +34,10 @@ class ErrorHandlersTest(converter_testing.TestCase):
raise ValueError()
node, ctx = self.prepare(test_fn, {})
- anno.setanno(node, anno.Basic.ORIGIN,
- origin_info.OriginInfo(None, None, None))
+ anno.setanno(
+ node, anno.Basic.ORIGIN,
+ origin_info.OriginInfo(None, 'test_function_name', 'test_code',
+ 'test_comment'))
node = error_handlers.transform(node, ctx)
with self.compiled(node, {}) as result:
with self.assertRaises(errors.GraphConstructionError):
diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py
index a93e4a8064..83a80c1f52 100644
--- a/tensorflow/contrib/autograph/core/converter.py
+++ b/tensorflow/contrib/autograph/core/converter.py
@@ -233,7 +233,7 @@ class Base(transformer.Base):
arg_values = []
for def_ in defs:
if (directive not in def_.directives or
- arg not in arg not in def_.directives[directive]):
+ arg not in def_.directives[directive]):
continue
arg_value = def_.directives[directive][arg]
for prev_value in arg_values:
diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py
index c0e2c74e47..404c1f5456 100644
--- a/tensorflow/contrib/autograph/core/errors_test.py
+++ b/tensorflow/contrib/autograph/core/errors_test.py
@@ -43,7 +43,8 @@ class RuntimeErrorsTest(test.TestCase):
filename = tf_inspect.getsourcefile(function)
lineno += line_offset
loc = origin_info.LineLocation(filename, lineno)
- origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code')
+ origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code',
+ 'test_comment')
return loc, origin
def test_improved_errors_basic(self):
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
index 73125eb452..7e7ef5a3e2 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
+++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
@@ -44,6 +44,33 @@ class ModelWithStaticConditional(object):
return x
+class BasicBlock(tf.keras.Model):
+
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+ self.conv1 = tf.keras.layers.Conv2D(8, 3)
+ self.pool = tf.keras.layers.GlobalAveragePooling2D()
+ self.dense = tf.keras.layers.Dense(3)
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.pool(x)
+ x = self.dense(x)
+ return x
+
+
+class CompoundModel(tf.keras.Model):
+
+ def __init__(self):
+ super(CompoundModel, self).__init__()
+ self.block = BasicBlock()
+
+ @autograph.convert(recursive=True)
+ def call(self, x):
+ x = self.block(x) # pylint: disable=not-callable
+ return x
+
+
class KerasTest(tf.test.TestCase):
def test_basic(self):
@@ -57,6 +84,20 @@ class KerasTest(tf.test.TestCase):
model = ModelWithStaticConditional(True)
self.assertEqual(model.call(), 25)
+ def test_recursive_true(self):
+ with self.assertRaisesRegexp(NotImplementedError,
+ 'Object conversion is not yet supported.'):
+ with tf.Graph().as_default():
+ model = CompoundModel()
+ model.build(tf.TensorShape((None, 10, 10, 1)))
+ init = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ sample_input = tf.random_uniform((1, 10, 10, 1))
+ output = model(sample_input) # pylint: disable=not-callable
+ self.assertEqual(sess.run(output).shape, (1, 3))
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index 0adff76a9f..4729c735c6 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -68,7 +68,8 @@ def convert(recursive=False, verbose=False, arg_types=None):
@wraps(f)
def wrapper(*args, **kwargs):
- return converted_call(f, recursive, verbose, arg_types, *args, **kwargs)
+ return converted_call(f, recursive, verbose, True, arg_types, *args,
+ **kwargs)
wrapper = tf_decorator.make_decorator(f, wrapper)
@@ -129,12 +130,12 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
return decorator
-def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
+def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
+ **kwargs):
"""Compiles a function call inline."""
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
-
- if conversion.is_whitelisted_for_graph(f):
+ if not force_conversion and conversion.is_whitelisted_for_graph(f):
return f(*args, **kwargs)
unknown_arg_value = object() # Sentinel for arguments of unknown value
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 754baa87b0..803fde9089 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -183,8 +183,8 @@ class ApiTest(test.TestCase):
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
- x //= api.converted_call(self.called_member, False, False, {}, self,
- a)
+ x //= api.converted_call(self.called_member, False, False, False, {},
+ self, a)
return x
tc = TestClass()
@@ -195,7 +195,7 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_converted_call_builtin(self):
- x = api.converted_call(range, False, False, {}, 3)
+ x = api.converted_call(range, False, False, False, {}, 3)
self.assertEqual((0, 1, 2), tuple(x))
def test_converted_call_function(self):
@@ -206,7 +206,7 @@ class ApiTest(test.TestCase):
return x
with self.test_session() as sess:
- x = api.converted_call(test_fn, False, False, {},
+ x = api.converted_call(test_fn, False, False, False, {},
constant_op.constant(-1))
self.assertEqual(1, sess.run(x))
@@ -224,7 +224,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc.test_method, False, False, {}, tc)
+ x = api.converted_call(tc.test_method, False, False, False, {}, tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_method_by_class(self):
@@ -241,7 +241,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(TestClass.test_method, False, False, {}, tc)
+ x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_callable_object(self):
@@ -258,7 +258,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc, False, False, {})
+ x = api.converted_call(tc, False, False, False, {})
self.assertEqual(1, sess.run(x))
def test_converted_call_constructor(self):
@@ -274,7 +274,7 @@ class ApiTest(test.TestCase):
return self.x
with self.test_session() as sess:
- tc = api.converted_call(TestClass, False, False, {},
+ tc = api.converted_call(TestClass, False, False, False, {},
constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
@@ -286,11 +286,12 @@ class ApiTest(test.TestCase):
return x == 0
with self.test_session() as sess:
- x = api.converted_call(f, False, False, {}, constant_op.constant(0))
+ x = api.converted_call(f, False, False, False, {},
+ constant_op.constant(0))
self.assertTrue(sess.run(x))
converted_f = api.to_graph(f)
- x = api.converted_call(converted_f, False, False, {},
+ x = api.converted_call(converted_f, False, False, False, {},
constant_op.constant(0))
self.assertTrue(sess.run(x))
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index afb10d4d8b..fc8a976d3f 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -118,6 +118,17 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types):
node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
elif tf_inspect.ismethod(o):
node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
+ # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
+ elif hasattr(o, '__class__'):
+ raise NotImplementedError(
+ 'Object conversion is not yet supported. If you are '
+ 'trying to convert code that uses an existing object, '
+ 'try including the creation of that object in the '
+ 'conversion. For example, instead of converting the method '
+ 'of a class, try converting the entire class instead. '
+ 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
+ 'contrib/autograph/README.md#using-the-functional-api '
+ 'for more information.')
else:
raise ValueError(
'Entity "%s" has unsupported type "%s". Only functions and classes are '
@@ -181,7 +192,7 @@ def class_to_graph(c, program_ctx):
class_name = namer.compiled_class_name(c.__name__, c)
# TODO(mdan): This needs to be explained more thoroughly.
- # Process any base classes: if the sueprclass if of a whitelisted type, an
+  # Process any base classes: if the superclass is of a whitelisted type, an
# absolute import line is generated. Otherwise, it is marked for conversion
# (as a side effect of the call to namer.compiled_class_name() followed by
# program_ctx.update_name_map(namer)).
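The new NotImplementedError above recommends converting a whole class rather than a bound method. A minimal, hypothetical sketch of that pattern (the class and method names are illustrative and not part of this change):

```python
from tensorflow.contrib import autograph as ag


class Model(object):
  """Illustrative class; not part of this change."""

  def call(self, x):
    if x > 0:  # data-dependent branch that AutoGraph rewrites
      x = x * 2
    return x


# Supported: convert the entire class, then use the converted version.
ConvertedModel = ag.to_graph(Model)

# After this change, converting a bound method of an existing instance,
# e.g. ag.to_graph(Model().call), raises NotImplementedError with the
# guidance quoted above instead of failing later in conversion.
```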
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index 1c5d4d09c4..86432573a7 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -50,7 +50,7 @@ class ConversionTest(test.TestCase):
self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant))
def test_entity_to_graph_unsupported_types(self):
- with self.assertRaises(ValueError):
+ with self.assertRaises(NotImplementedError):
program_ctx = self._simple_program_ctx()
conversion.entity_to_graph('dummy', program_ctx, None, None)
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 988df70157..be38d3f534 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -212,12 +212,12 @@ def if_stmt(cond, body, orelse):
Tuple containing the statement outputs.
"""
if tensor_util.is_tensor(cond):
- return _tf_if_stmt(cond, body, orelse)
+ return tf_if_stmt(cond, body, orelse)
else:
return _py_if_stmt(cond, body, orelse)
-def _tf_if_stmt(cond, body, orelse):
+def tf_if_stmt(cond, body, orelse):
"""Overload of if_stmt that stages a TF cond."""
return control_flow_ops.cond(cond, body, orelse)
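A hedged sketch of how `if_stmt` dispatches after this rename (the predicates and branch values below are placeholders):

```python
import tensorflow as tf

from tensorflow.contrib.autograph.operators import control_flow

# Tensor predicate: staged as a tf.cond via tf_if_stmt.
staged = control_flow.if_stmt(tf.constant(True),
                              lambda: tf.constant(1),
                              lambda: tf.constant(2))

# Plain Python predicate: evaluated immediately as an ordinary conditional.
plain = control_flow.if_stmt(3 > 2, lambda: 1, lambda: 2)
```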
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
index 9f98e48a6a..b60651a30e 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -18,8 +18,10 @@ from __future__ import division
from __future__ import print_function
import collections
+import tokenize
import gast
+import six
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
@@ -56,13 +58,14 @@ class Location(
class OriginInfo(
collections.namedtuple(
'OriginInfo',
- ('loc', 'function_name', 'source_code_line'))):
+ ('loc', 'function_name', 'source_code_line', 'comment'))):
"""Container for information about the source code before conversion.
Attributes:
loc: Location
function_name: Optional[Text]
source_code_line: Text
+ comment: Optional[Text]
"""
def as_frame(self):
@@ -152,6 +155,15 @@ def resolve(nodes, source, function=None):
function_lineno = None
function_filepath = None
+ # TODO(mdan): Pull this to a separate utility.
+ code_reader = six.StringIO(source)
+ comment_map = {}
+ for token in tokenize.generate_tokens(code_reader.readline):
+ tok_type, tok_string, loc, _, _ = token
+ srow, _ = loc
+ if tok_type == tokenize.COMMENT:
+ comment_map[srow] = tok_string.strip()[1:].strip()
+
source_lines = source.split('\n')
for node in nodes:
for n in gast.walk(node):
@@ -162,12 +174,13 @@ def resolve(nodes, source, function=None):
source_code_line = source_lines[lineno_in_body - 1]
if function:
- source_lineno = function_lineno + lineno_in_body - 1
+ source_lineno = function_lineno + lineno_in_body
function_name = function.__name__
else:
source_lineno = lineno_in_body
function_name = None
location = Location(function_filepath, source_lineno, n.col_offset)
- origin = OriginInfo(location, function_name, source_code_line)
+ origin = OriginInfo(location, function_name,
+ source_code_line, comment_map.get(source_lineno))
anno.setanno(n, anno.Basic.ORIGIN, origin)
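The comment map built inside `resolve()` above can be exercised on its own; a small, self-contained sketch (the helper name is hypothetical and not part of the TensorFlow API):

```python
import tokenize

import six


def extract_comments(source):
  """Maps 1-based line numbers to the text of any '#' comment on that line."""
  comment_map = {}
  reader = six.StringIO(source)
  for tok_type, tok_string, (srow, _), _, _ in tokenize.generate_tokens(
      reader.readline):
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()
  return comment_map


print(extract_comments('x = 1  # set x\ny = 2\n'))  # {1: 'set x'}
```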
diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py
index 6d7d8b1622..eeaa13007e 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info_test.py
+++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py
@@ -85,16 +85,19 @@ class OriginInfoTest(test.TestCase):
self.assertEqual(origin.loc.lineno, 1)
self.assertEqual(origin.loc.col_offset, 0)
self.assertEqual(origin.source_code_line, 'def test_fn(x):')
+ self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 2)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' """Docstring."""')
+ self.assertIsNone(origin.comment)
origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
self.assertEqual(origin.loc.lineno, 3)
self.assertEqual(origin.loc.col_offset, 2)
self.assertEqual(origin.source_code_line, ' return x # comment')
+ self.assertEqual(origin.comment, 'comment')
if __name__ == '__main__':
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
index d7c71a20ed..88a3909de4 100644
--- a/tensorflow/contrib/bigtable/README.md
+++ b/tensorflow/contrib/bigtable/README.md
@@ -324,7 +324,7 @@ If you encounter a log line that includes the following:
"filename":"/usr/share/grpc/roots.pem"
```
-you likely need to copy the [gRPC roots.pem file][grpcPem] to
+you likely need to copy the [gRPC `roots.pem` file][grpcPem] to
`/usr/share/grpc/roots.pem` on your local machine.
[grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem
@@ -338,7 +338,10 @@ are available.
- **Compute Engine**: When running on Compute Engine, the client will often use
the service account from the virtual machine's metadata service. Be sure to
authorize your Compute Engine VM to have access to the Cloud Bigtable service
- when creating your VM.
+ when creating your VM, or [update the VM's scopes][update-vm-scopes] on a
+ running VM if you run into this issue.
- **Cloud TPU**: Your Cloud TPUs run with the designated Cloud TPU service
account dedicated to your GCP project. Ensure the service account has been
authorized via the Cloud Console to access your Cloud Bigtable instances.
+
+[update-vm-scopes]: https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index fd30aa8bbb..e6ef513c40 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""The Python API for TensorFlow's Bigtable integration.
+"""The Python API for TensorFlow's Cloud Bigtable integration.
TensorFlow has support for reading from and writing to Cloud Bigtable. To use
-the Bigtable TensorFlow integration, first create a BigtableClient (which
-configures your connection to Cloud Bigtable), and then open a Table. The Table
-object then allows you to create numerous @{tf.data.Dataset}s to read data, or
-write a @{tf.data.Dataset} object to the underlying Bigtable Table.
+TensorFlow + Cloud Bigtable integration, first create a BigtableClient to
+configure your connection to Cloud Bigtable, and then create a BigtableTable
+object to allow you to create numerous @{tf.data.Dataset}s to read data, or
+write a @{tf.data.Dataset} object to the underlying Cloud Bigtable table.
-For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable.
+For background on Cloud Bigtable, see: https://cloud.google.com/bigtable .
"""
from __future__ import absolute_import
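A hedged end-to-end sketch of the workflow described in this docstring, combining the examples that appear later in the file (project, instance, and table names are placeholders; method names follow the docstring examples in this module):

```python
from tensorflow.contrib.bigtable.python.ops import bigtable_api

# Configure the connection, then open a table.
client = bigtable_api.BigtableClient("my-project", "my-instance")
table = client.table("my_table")

# Build tf.data pipelines against the table, as in the docstring examples.
key_dataset = table.get_keys_prefix("imagenet")
images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
                                                ("cf2", "label")))
```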
@@ -48,7 +48,7 @@ class BigtableClient(object):
"""BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
- `table` method to open a Bigtable Table.
+ `table` method to open a Bigtable table.
"""
def __init__(self,
@@ -94,7 +94,7 @@ class BigtableClient(object):
project_id, instance_id, connection_pool_size, max_receive_message_size)
def table(self, name, snapshot=None):
- """Opens a table and returns a `BigtableTable` object.
+ """Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object.
Args:
name: A `tf.string` `tf.Tensor` name of the table to open.
@@ -102,8 +102,8 @@ class BigtableClient(object):
request the creation of a snapshot. (Note: currently unimplemented.)
Returns:
- A `BigtableTable` python object representing the operations available on
- the table.
+ A `tf.contrib.bigtable.BigtableTable` Python object representing the
+ operations available on the table.
"""
# TODO(saeta): Implement snapshot functionality.
table = gen_bigtable_ops.bigtable_table(self._resource, name)
@@ -133,7 +133,8 @@ class BigtableTable(object):
"""Retrieves the values of columns for a dataset of keys.
Example usage:
- ```
+
+ ```python
table = bigtable_client.table("my_table")
key_dataset = table.get_keys_prefix("imagenet")
images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
@@ -144,7 +145,8 @@ class BigtableTable(object):
Alternatively, you can use keyword arguments to specify the columns to
capture. Example (same as above, rewritten):
- ```
+
+ ```python
table = bigtable_client.table("my_table")
key_dataset = table.get_keys_prefix("imagenet")
images = key_dataset.apply(table.lookup_columns(
@@ -152,15 +154,17 @@ class BigtableTable(object):
training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
```
- Note: certain kwargs keys are reserved, and thus some column families cannot
- be identified using the kwargs syntax. Instead, please use the args syntax.
- This list includes:
+ Note: certain `kwargs` keys are reserved, and thus, some column families
+ cannot be identified using the `kwargs` syntax. Instead, please use the
+ `args` syntax. This list includes:
+
- 'name'
- This list can change at any time.
+
+ Note: this list can change at any time.
Args:
*args: A list of tuples containing (column family, column name) pairs.
- **kwargs: Column families and
+ **kwargs: Column families (keys) and column qualifiers (values).
Returns:
A function that can be passed to `tf.data.Dataset.apply` to retrieve the
@@ -712,7 +716,7 @@ class _BigtableScanDataset(dataset_ops.Dataset):
class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset):
- """_BigtableKeyRangeDataset returns key pairs from the Bigtable.
+ """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
"""
def __init__(self, table, prefix, start, end):
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index f4a375328e..5fcb19a47a 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -191,7 +191,7 @@ py_test(
py_test(
name = "estimator_test",
- size = "medium",
+ size = "large",
srcs = ["estimator_test.py"],
srcs_version = "PY2AND3",
tags = [
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
index dbfa69edcb..194a5c8754 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
@@ -86,7 +86,8 @@ def _dnn_tree_combined_model_fn(
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
use_core_versions=False,
- output_type=model.ModelBuilderOutputType.MODEL_FN_OPS):
+ output_type=model.ModelBuilderOutputType.MODEL_FN_OPS,
+ override_global_step_value=None):
"""DNN and GBDT combined model_fn.
Args:
@@ -135,6 +136,12 @@ def _dnn_tree_combined_model_fn(
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
+    override_global_step_value: Value to reset the global step to after
+      training is done. This is particularly useful for hyperparameter
+      tuning, which can't recognize early stopping due to the number of
+      trees. If None, the global step is not overridden.
Returns:
A `ModelFnOps` object.
@@ -350,7 +357,8 @@ def _dnn_tree_combined_model_fn(
trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
tree_train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
+ finalized_trees,
+ override_global_step_value)
])
return model_fn_ops
@@ -378,7 +386,8 @@ def _dnn_tree_combined_model_fn(
trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
tree_spec.train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
+ finalized_trees,
+ override_global_step_value)
]
fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
list(fusion_spec.training_hooks))
@@ -411,7 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedClassifier instance.
Args:
@@ -467,6 +477,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+    override_global_step_value: Value to reset the global step to after
+      training is done. This is particularly useful for hyperparameter
+      tuning, which can't recognize early stopping due to the number of
+      trees. If None, the global step is not overridden.
"""
head = head_lib.multi_class_head(
n_classes=n_classes,
@@ -497,7 +511,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedClassifier, self).__init__(
model_fn=_model_fn,
@@ -531,7 +546,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedRegressor instance.
Args:
@@ -587,6 +603,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+    override_global_step_value: Value to reset the global step to after
+      training is done. This is particularly useful for hyperparameter
+      tuning, which can't recognize early stopping due to the number of
+      trees. If None, the global step is not overridden.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -622,7 +642,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedRegressor, self).__init__(
model_fn=_model_fn,
@@ -657,7 +678,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedEstimator instance.
Args:
@@ -708,6 +730,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+    override_global_step_value: Value to reset the global step to after
+      training is done. This is particularly useful for hyperparameter
+      tuning, which can't recognize early stopping due to the number of
+      trees. If None, the global step is not overridden.
"""
def _model_fn(features, labels, mode, config):
@@ -732,7 +758,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn,
@@ -832,7 +859,8 @@ class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator):
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC,
- use_core_versions=True)
+ use_core_versions=True,
+ override_global_step_value=None)
super(CoreDNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 2df879f924..870ce2442b 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -22,8 +22,10 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.python.utils import losses
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.estimator.canned import head as core_head_lib
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.losses import losses as core_losses
# ================== Old estimator interface===================================
@@ -49,7 +51,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -83,6 +86,14 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
+      override_global_step_value: Value to reset the global step to after
+        training is done. This should be used to reset the global step to a
+        number greater than the number of steps used to train the current
+        ensemble. For example, the usual way is to train a number of trees
+        and set a very large number of training steps. When the training is
+        done (the requested trees were trained), this parameter can be used
+        to set the global step to a large value, making it look like that
+        many training steps ran. If None, the global step is not overridden.
Raises:
ValueError: If learner_config is not valid.
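A hedged sketch of `override_global_step_value` in use, mirroring the new `testOverridesGlobalSteps` test added later in this commit (the learner config, feature column, and input function follow the test helpers; import paths are my best guess and `model_dir` is a placeholder):

```python
import tempfile

import tensorflow as tf

from tensorflow.contrib.boosted_trees.estimator_batch import estimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column


def _train_input_fn():
  # Tiny numeric dataset, following the helper in estimator_test.py.
  return {"x": tf.constant([[2.], [1.], [1.]])}, tf.constant([[1], [0], [0]])


learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 2

classifier = estimator.GradientBoostedDecisionTreeClassifier(
    learner_config=learner_config,
    num_trees=1,
    examples_per_layer=3,
    model_dir=tempfile.mkdtemp(),
    feature_columns=[contrib_feature_column.real_valued_column("x")],
    override_global_step_value=10000000)

classifier.fit(input_fn=_train_input_fn, steps=15)
# Training stops once the single tree is finalized; the StopAfterNTrees hook
# then assigns the global step to 10000000 before requesting stop.
```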
@@ -123,6 +134,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -146,7 +158,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -180,6 +193,14 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+      override_global_step_value: Value to reset the global step to after
+        training is done. This should be used to reset the global step to a
+        number greater than the number of steps used to train the current
+        ensemble. For example, the usual way is to train a number of trees
+        and set a very large number of training steps. When the training is
+        done (the requested trees were trained), this parameter can be used
+        to set the global step to a large value, making it look like that
+        many training steps ran. If None, the global step is not overridden.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -203,6 +224,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -228,7 +250,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -258,6 +281,14 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+      override_global_step_value: Value to reset the global step to after
+        training is done. This should be used to reset the global step to a
+        number greater than the number of steps used to train the current
+        ensemble. For example, the usual way is to train a number of trees
+        and set a very large number of training steps. When the training is
+        done (the requested trees were trained), this parameter can be used
+        to set the global step to a large value, making it look like that
+        many training steps ran. If None, the global step is not overridden.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -272,6 +303,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -281,24 +313,23 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
class GradientBoostedDecisionTreeRanker(estimator.Estimator):
"""A ranking estimator using gradient boosted decision trees."""
- def __init__(
- self,
- learner_config,
- examples_per_layer,
- head,
- ranking_model_pair_keys,
- num_trees=None,
- feature_columns=None,
- weight_column_name=None,
- model_dir=None,
- config=None,
- label_keys=None,
- feature_engineering_fn=None,
- logits_modifier_function=None,
- center_bias=False,
- use_core_libs=False,
- output_leaf_index=False,
- ):
+ def __init__(self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ use_core_libs=False,
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -338,7 +369,14 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
-
+      override_global_step_value: Value to reset the global step to after
+        training is done. This should be used to reset the global step to a
+        number greater than the number of steps used to train the current
+        ensemble. For example, the usual way is to train a number of trees
+        and set a very large number of training steps. When the training is
+        done (the requested trees were trained), this parameter can be used
+        to set the global step to a large value, making it look like that
+        many training steps ran. If None, the global step is not overridden.
Raises:
ValueError: If learner_config is not valid.
"""
@@ -357,6 +395,7 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -366,6 +405,25 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
# The estimators below use new core Estimator interface and must be used with
# new feature columns and heads.
+# For multiclass classification, use the following head since it uses loss
+# that is twice differentiable.
+def core_multiclass_head(n_classes):
+ """Core head for multiclass problems."""
+
+ def loss_fn(labels, logits):
+ result = losses.per_example_maxent_loss(
+ labels=labels, logits=logits, weights=None, num_classes=n_classes)
+ return result[0]
+
+ # pylint:disable=protected-access
+ head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=n_classes,
+ loss_fn=loss_fn,
+ loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ # pylint:enable=protected-access
+
+ return head_fn
+
class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
"""An estimator using gradient boosted decision trees.
@@ -435,6 +493,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
+ 'override_global_step_value': None
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
@@ -445,22 +504,20 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
"""A ranking estimator using gradient boosted decision trees."""
- def __init__(
- self,
- learner_config,
- examples_per_layer,
- head,
- ranking_model_pair_keys,
- num_trees=None,
- feature_columns=None,
- weight_column_name=None,
- model_dir=None,
- config=None,
- label_keys=None,
- logits_modifier_function=None,
- center_bias=False,
- output_leaf_index=False,
- ):
+ def __init__(self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -519,6 +576,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
+ 'override_global_step_value': None
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 9e9febbbef..68d710d713 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -25,10 +25,12 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
+from tensorflow.python.training import checkpoint_utils
def _train_input_fn():
@@ -37,6 +39,15 @@ def _train_input_fn():
return features, label
+def _multiclass_train_input_fn():
+ features = {
+ "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]])
+ }
+ label = constant_op.constant(
+ [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32)
+ return features, label
+
+
def _ranking_train_input_fn():
features = {
"a.f1": constant_op.constant([[3.], [0.3], [1.]]),
@@ -68,6 +79,10 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
self._export_dir_base = tempfile.mkdtemp() + "export/"
gfile.MkDir(self._export_dir_base)
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+
def testFitAndEvaluateDontThrowException(self):
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
@@ -202,6 +217,126 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
model.evaluate(input_fn=_ranking_train_input_fn, steps=1)
model.predict(input_fn=_infer_ranking_train_input_fn)
+ def testDoesNotOverrideGlobalSteps(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 2
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")],
+ output_leaf_index=False)
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ # When no override of global steps, 5 steps were used.
+ self._assert_checkpoint(classifier.model_dir, global_step=5)
+
+ def testOverridesGlobalSteps(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 2
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")],
+ output_leaf_index=False,
+ override_global_step_value=10000000)
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ self._assert_checkpoint(classifier.model_dir, global_step=10000000)
+
+ def testFitAndEvaluateMultiClassTreePerClassDontThrowException(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 3
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.multi_class_strategy = (
+ learner_pb2.LearnerConfig.TREE_PER_CLASS)
+
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ n_classes=learner_config.num_classes,
+ num_trees=1,
+ examples_per_layer=7,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")])
+
+ classifier.fit(input_fn=_multiclass_train_input_fn, steps=100)
+ classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+ classifier.export(self._export_dir_base)
+ result_iter = classifier.predict(input_fn=_eval_input_fn)
+ for prediction_dict in result_iter:
+ self.assertTrue("classes" in prediction_dict)
+
+ def testFitAndEvaluateMultiClassDiagonalDontThrowException(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 3
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.multi_class_strategy = (
+ learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
+
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ n_classes=learner_config.num_classes,
+ num_trees=1,
+ examples_per_layer=7,
+ model_dir=model_dir,
+ config=config,
+ center_bias=False,
+ feature_columns=[contrib_feature_column.real_valued_column("x")])
+
+ classifier.fit(input_fn=_multiclass_train_input_fn, steps=100)
+ classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+ classifier.export(self._export_dir_base)
+ result_iter = classifier.predict(input_fn=_eval_input_fn)
+ for prediction_dict in result_iter:
+ self.assertTrue("classes" in prediction_dict)
+
+ def testFitAndEvaluateMultiClassFullDontThrowException(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 3
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.multi_class_strategy = (
+ learner_pb2.LearnerConfig.FULL_HESSIAN)
+
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ n_classes=learner_config.num_classes,
+ num_trees=1,
+ examples_per_layer=7,
+ model_dir=model_dir,
+ config=config,
+ center_bias=False,
+ feature_columns=[contrib_feature_column.real_valued_column("x")])
+
+ classifier.fit(input_fn=_multiclass_train_input_fn, steps=100)
+ classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+ classifier.export(self._export_dir_base)
+ result_iter = classifier.predict(input_fn=_eval_input_fn)
+ for prediction_dict in result_iter:
+ self.assertTrue("classes" in prediction_dict)
+
class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
@@ -257,6 +392,87 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
est.evaluate(input_fn=_ranking_train_input_fn, steps=1)
est.predict(input_fn=_infer_ranking_train_input_fn)
+  def testFitAndEvaluateMultiClassTreePerClassDontThrowException(self):
+ n_classes = 3
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = n_classes
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.multi_class_strategy = (
+ learner_pb2.LearnerConfig.TREE_PER_CLASS)
+
+ head_fn = estimator.core_multiclass_head(n_classes=n_classes)
+
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ learner_config=learner_config,
+ head=head_fn,
+ num_trees=1,
+ center_bias=False,
+ examples_per_layer=7,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")])
+
+ classifier.train(input_fn=_multiclass_train_input_fn, steps=100)
+ classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
+ classifier.predict(input_fn=_eval_input_fn)
+
+ def testFitAndEvaluateMultiClassDiagonalDontThrowException(self):
+ n_classes = 3
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = n_classes
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.multi_class_strategy = (
+ learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
+
+ head_fn = estimator.core_multiclass_head(n_classes=n_classes)
+
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ learner_config=learner_config,
+ head=head_fn,
+ num_trees=1,
+ center_bias=False,
+ examples_per_layer=7,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")])
+
+ classifier.train(input_fn=_multiclass_train_input_fn, steps=100)
+ classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
+ classifier.predict(input_fn=_eval_input_fn)
+
+ def testFitAndEvaluateMultiClassFullDontThrowException(self):
+ n_classes = 3
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = n_classes
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.multi_class_strategy = (
+ learner_pb2.LearnerConfig.FULL_HESSIAN)
+
+ head_fn = estimator.core_multiclass_head(n_classes=n_classes)
+
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ learner_config=learner_config,
+ head=head_fn,
+ num_trees=1,
+ center_bias=False,
+ examples_per_layer=7,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")])
+
+ classifier.train(input_fn=_multiclass_train_input_fn, steps=100)
+ classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
+ classifier.predict(input_fn=_eval_input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 161cc42cb0..04b46c3483 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -58,6 +58,10 @@ def model_builder(features,
* weight_column_name: The name of weight column.
* center_bias: Whether a separate tree should be created for first fitting
the bias.
+    * override_global_step_value: Value to reset the global step to after
+      training is done. This is particularly useful for hyperparameter
+      tuning, which can't recognize early stopping due to the number of
+      trees. If None, the global step is not overridden.
config: `RunConfig` of the estimator.
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
(new interface).
@@ -76,6 +80,7 @@ def model_builder(features,
use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
+ override_global_step_value = params.get("override_global_step_value", None)
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -136,7 +141,8 @@ def model_builder(features,
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees))
+ finalized_trees,
+ override_global_step_value))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
@@ -206,6 +212,10 @@ def ranking_model_builder(features,
for left and right part of the training pairs for ranking. For example,
for an Example with features "a.f1" and "b.f1", the keys would be
("a", "b").
+    * override_global_step_value: Value to reset the global step to after
+      training is done. This is particularly useful for hyperparameter
+      tuning, which can't recognize early stopping due to the number of
+      trees. If None, the global step is not overridden.
config: `RunConfig` of the estimator.
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
(new interface).
@@ -226,6 +236,7 @@ def ranking_model_builder(features,
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
+ override_global_step_value = params.get("override_global_step_value", None)
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -347,7 +358,8 @@ def ranking_model_builder(features,
gbdt_model_main.get_number_of_trees_tensor())
training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees))
+ finalized_trees,
+ override_global_step_value))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py
index 2e4151cac4..f137ada355 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArg
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.summary_io import SummaryWriterCache
@@ -150,12 +151,23 @@ class FeedFnHook(session_run_hook.SessionRunHook):
class StopAfterNTrees(session_run_hook.SessionRunHook):
"""Stop training after building N full trees."""
- def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor):
+ def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor,
+ override_global_step_value=None):
self._num_trees = n
# num_attempted_trees_tensor and num_finalized_trees_tensor are both
# tensors.
self._num_attempted_trees_tensor = num_attempted_trees_tensor
self._num_finalized_trees_tensor = num_finalized_trees_tensor
+ self._override_global_step_value = override_global_step_value
+
+ def begin(self):
+ self._global_step_tensor = training_util.get_global_step()
+ if self._global_step_tensor is None:
+ raise RuntimeError("Global step should be created.")
+
+ if self._override_global_step_value is not None:
+ self._override_global_step_op = state_ops.assign(
+ self._global_step_tensor, self._override_global_step_value)
def before_run(self, run_context):
del run_context # unused by StopTrainingAfterNTrees.
@@ -175,6 +187,9 @@ class StopAfterNTrees(session_run_hook.SessionRunHook):
num_attempted_trees > 2 * self._num_trees):
logging.info("Requesting stop since we have reached %d trees.",
num_finalized_trees)
+ if self._override_global_step_value is not None:
+ logging.info("Overriding global steps value.")
+ run_context.session.run(self._override_global_step_op)
run_context.request_stop()
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 19e053fcb6..ba5ef700c5 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -353,6 +353,9 @@ class GradientBoostedDecisionTreeModel(object):
self._gradient_shape = tensor_shape.scalar()
self._hessian_shape = tensor_shape.scalar()
else:
+ if center_bias:
+ raise ValueError("Center bias should be False for multiclass.")
+
self._gradient_shape = tensor_shape.TensorShape([logits_dimension])
if (learner_config.multi_class_strategy ==
learner_pb2.LearnerConfig.FULL_HESSIAN):
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index f9dc3effd0..1ab150d74a 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -148,6 +148,9 @@ class TPUClusterResolver(ClusterResolver):
else:
tpu = self._envVarFallback()
+ if tpu is None:
+ raise ValueError('Please provide a TPU Name to connect to.')
+
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._job_name = job_name
self._credentials = credentials
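After this change, constructing a TPUClusterResolver without any discoverable TPU name fails fast instead of deferring the error; a hedged sketch (the TPU name is a placeholder):

```python
from tensorflow.contrib.cluster_resolver import TPUClusterResolver

# Raises ValueError if no TPU name can be resolved from the constructor
# argument or the resolver's environment fallbacks; pass one explicitly.
resolver = TPUClusterResolver(tpu='my-tpu')
```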
diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake
index 45a0096085..33bb31148d 100644
--- a/tensorflow/contrib/cmake/external/eigen.cmake
+++ b/tensorflow/contrib/cmake/external/eigen.cmake
@@ -19,6 +19,12 @@
# build_file = "eigen.BUILD",
#)
+option(eigen_PATCH_FILE "Patch file to apply to eigen" OFF)
+set(eigen_PATCH_COMMAND "")
+if(eigen_PATCH_FILE)
+ set(eigen_PATCH_COMMAND PATCH_COMMAND patch -p0 -i "${eigen_PATCH_FILE}")
+endif(eigen_PATCH_FILE)
+
include (ExternalProject)
# We parse the current Eigen version and archive hash from the bazel configuration
@@ -45,6 +51,7 @@ ExternalProject_Add(eigen
URL ${eigen_URL}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
INSTALL_DIR "${eigen_INSTALL}"
+ ${eigen_PATCH_COMMAND}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
diff --git a/tensorflow/contrib/cmake/external/highwayhash.cmake b/tensorflow/contrib/cmake/external/highwayhash.cmake
index a6e8a38d8c..7d260b85f2 100644
--- a/tensorflow/contrib/cmake/external/highwayhash.cmake
+++ b/tensorflow/contrib/cmake/external/highwayhash.cmake
@@ -20,14 +20,6 @@ set(highwayhash_TAG be5edafc2e1a455768e260ccd68ae7317b6690ee)
set(highwayhash_BUILD ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/src/highwayhash)
set(highwayhash_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/install)
-# put highwayhash includes in the directory where they are expected
-add_custom_target(highwayhash_create_destination_dir
- COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash
- DEPENDS highwayhash)
-
-add_custom_target(highwayhash_copy_headers_to_destination
- DEPENDS highwayhash_create_destination_dir)
-
if(WIN32)
set(highwayhash_HEADERS "${highwayhash_BUILD}/highwayhash/*.h")
set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/highwayhash.lib)
@@ -36,6 +28,20 @@ else()
set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/libhighwayhash.a)
endif()
+set(highwayhash_HEADERS
+ "${highwayhash_INSTALL}/include/code_annotation.h"
+ "${highwayhash_INSTALL}/include/highway_tree_hash.h"
+ "${highwayhash_INSTALL}/include/scalar_highway_tree_hash.h"
+ "${highwayhash_INSTALL}/include/scalar_sip_tree_hash.h"
+ "${highwayhash_INSTALL}/include/sip_hash.h"
+ "${highwayhash_INSTALL}/include/sip_tree_hash.h"
+ "${highwayhash_INSTALL}/include/sse41_highway_tree_hash.h"
+ "${highwayhash_INSTALL}/include/state_helpers.h"
+ "${highwayhash_INSTALL}/include/types.h"
+ "${highwayhash_INSTALL}/include/vec.h"
+ "${highwayhash_INSTALL}/include/vec2.h"
+)
+
ExternalProject_Add(highwayhash
PREFIX highwayhash
GIT_REPOSITORY ${highwayhash_URL}
@@ -50,5 +56,15 @@ ExternalProject_Add(highwayhash
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DCMAKE_INSTALL_PREFIX:STRING=${highwayhash_INSTALL})
-add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy_directory ${highwayhash_INSTALL}/include/ ${highwayhash_INCLUDE_DIR}/highwayhash)
+# put highwayhash includes in the directory where they are expected
+add_custom_target(highwayhash_create_destination_dir
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash
+ DEPENDS highwayhash)
+
+add_custom_target(highwayhash_copy_headers_to_destination
+ DEPENDS highwayhash_create_destination_dir)
+
+foreach(header_file ${highwayhash_HEADERS})
+ add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${highwayhash_INCLUDE_DIR}/highwayhash/)
+endforeach()
diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake
index eba3bcfc79..1d638e6402 100644
--- a/tensorflow/contrib/cmake/external/nsync.cmake
+++ b/tensorflow/contrib/cmake/external/nsync.cmake
@@ -20,14 +20,6 @@ set(nsync_TAG 1.20.0)
set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync)
set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install)
-# put nsync includes in the directory where they are expected
-add_custom_target(nsync_create_destination_dir
- COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR}
- DEPENDS nsync)
-
-add_custom_target(nsync_copy_headers_to_destination
- DEPENDS nsync_create_destination_dir)
-
if(WIN32)
set(nsync_HEADERS "${nsync_BUILD}/public/*.h")
set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync.lib)
@@ -49,7 +41,35 @@ ExternalProject_Add(nsync
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DCMAKE_INSTALL_PREFIX:STRING=${nsync_INSTALL}
- -DNSYNC_LANGUAGE:STRING=c++11)
+ -DNSYNC_LANGUAGE:STRING=c++11)
+
+set(nsync_HEADERS
+ "${nsync_INSTALL}/include/nsync.h"
+ "${nsync_INSTALL}/include/nsync_atomic.h"
+ "${nsync_INSTALL}/include/nsync_counter.h"
+ "${nsync_INSTALL}/include/nsync_cpp.h"
+ "${nsync_INSTALL}/include/nsync_cv.h"
+ "${nsync_INSTALL}/include/nsync_debug.h"
+ "${nsync_INSTALL}/include/nsync_mu.h"
+ "${nsync_INSTALL}/include/nsync_mu_wait.h"
+ "${nsync_INSTALL}/include/nsync_note.h"
+ "${nsync_INSTALL}/include/nsync_once.h"
+ "${nsync_INSTALL}/include/nsync_time.h"
+ "${nsync_INSTALL}/include/nsync_time_internal.h"
+ "${nsync_INSTALL}/include/nsync_waiter.h"
+)
+
+# put nsync includes in the directory where they are expected
+add_custom_target(nsync_create_destination_dir
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR}
+ DEPENDS nsync)
+
+add_custom_target(nsync_copy_headers_to_destination
+ DEPENDS nsync_create_destination_dir)
+
+foreach(header_file ${nsync_HEADERS})
+ add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${nsync_INCLUDE_DIR}/)
+endforeach()
+
-add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy_directory ${nsync_INSTALL}/include/ ${nsync_INCLUDE_DIR}/)
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 75e00f3267..9045290679 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -115,7 +115,6 @@ tensorflow/contrib/coder
tensorflow/contrib/coder/kernels
tensorflow/contrib/coder/ops
tensorflow/contrib/coder/python
-tensorflow/contrib/coder/python/layers
tensorflow/contrib/coder/python/ops
tensorflow/contrib/compiler
tensorflow/contrib/constrained_optimization
diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD
index a2c6e41303..855c824ead 100644
--- a/tensorflow/contrib/coder/BUILD
+++ b/tensorflow/contrib/coder/BUILD
@@ -1,5 +1,5 @@
# Description:
-# Contains tools related to data compression.
+# Contains ops related to data compression.
package(default_visibility = [
"//learning/brain:__subpackages__",
@@ -168,7 +168,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":coder_ops_py",
- ":entropybottleneck_py",
],
)
@@ -205,44 +204,3 @@ tf_py_test(
],
main = "python/ops/coder_ops_test.py",
)
-
-py_library(
- name = "entropybottleneck_py",
- srcs = [
- "python/layers/entropybottleneck.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":coder_ops_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:nn",
- "//tensorflow/python:ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/keras:engine",
- "//third_party/py/numpy",
- ],
-)
-
-tf_py_test(
- name = "entropybottleneck_py_test",
- srcs = [
- "python/layers/entropybottleneck_test.py",
- ],
- additional_deps = [
- ":entropybottleneck_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:variables",
- "//tensorflow/python:training",
- ],
- main = "python/layers/entropybottleneck_test.py",
-)
diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md
deleted file mode 100644
index c6c379c458..0000000000
--- a/tensorflow/contrib/coder/README.md
+++ /dev/null
@@ -1,73 +0,0 @@
-# Entropy coder
-
-This module contains range encoder and range decoder which can encode integer
-data into string with cumulative distribution functions (CDF).
-
-## Data and CDF values
-
-The data to be encoded should be non-negative integers in half-open interval
-`[0, m)`. Then a CDF is represented as an integral vector of length `m + 1`
-where `CDF(i) = f(Pr(X < i) * 2^precision)` for i = 0,1,...,m, and `precision`
-is an attribute in range `0 < precision <= 16`. The function `f` maps real
-values into integers, e.g., round or floor. It is important that to encode a
-number `i`, `CDF(i + 1) - CDF(i)` cannot be zero.
-
-Note that we used `Pr(X < i)` not `Pr(X <= i)`, and therefore CDF(0) = 0 always.
-
-## RangeEncode: data shapes and CDF shapes
-
-For each data element, its CDF has to be provided. Therefore if the shape of CDF
-should be `data.shape + (m + 1,)` in NumPy-like notation. For example, if `data`
-is a 2-D tensor of shape (10, 10) and its elements are in `[0, 64)`, then the
-CDF tensor should have shape (10, 10, 65).
-
-This may make CDF tensor too large, and in many applications all data elements
-may have the same probability distribution. To handle this, `RangeEncode`
-supports limited broadcasting CDF into data. Broadcasting is limited in the
-following sense:
-
-- All CDF axes but the last one is broadcasted into data but not the other way
- around,
-- The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`.
-
-In the previous example where data has shape (10, 10), the following are
-acceptable CDF shapes:
-
-- (10, 10, 65)
-- (1, 10, 65)
-- (10, 1, 65)
-- (1, 1, 65)
-
-## RangeDecode
-
-`RangeEncode` encodes neither data shape nor termination character. Therefore
-the decoder should know how many characters are encoded into the string, and
-`RangeDecode` takes the encoded data shape as the second argument. The same
-shape restrictions as `RangeEncode` inputs apply here.
-
-## Example
-
-```python
-data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32)
-
-histogram = tf.bincount(data, minlength=10, maxlength=10)
-cdf = tf.cumsum(histogram, exclusive=False)
-# CDF should have length m + 1.
-cdf = tf.pad(cdf, [[1, 0]])
-# CDF axis count must be one more than data.
-cdf = tf.reshape(cdf, [1, 1, -1])
-
-# Note that data has 2^14 elements, and therefore the sum of CDF is 2^14.
-data = tf.cast(data, tf.int16)
-encoded = coder.range_encode(data, cdf, precision=14)
-decoded = coder.range_decode(encoded, tf.shape(data), cdf, precision=14)
-
-# data and decoded should be the same.
-sess = tf.Session()
-x, y = sess.run((data, decoded))
-assert np.all(x == y)
-```
-
-## Authors
-Sung Jin Hwang (github: [ssjhv](https://github.com/ssjhv)) and Nick Johnston
-(github: [nmjohn](https://github.com/nmjohn))
diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py
index 99b8ac7595..8897312046 100644
--- a/tensorflow/contrib/coder/__init__.py
+++ b/tensorflow/contrib/coder/__init__.py
@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Data compression tools."""
+"""Data compression ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
-from tensorflow.contrib.coder.python.layers.entropybottleneck import *
from tensorflow.contrib.coder.python.ops.coder_ops import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py
deleted file mode 100644
index 0c997bd4fd..0000000000
--- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py
+++ /dev/null
@@ -1,697 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Entropy bottleneck layer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.coder.python.ops import coder_ops
-
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine import base_layer
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.summary import summary
-
-
-class EntropyBottleneck(base_layer.Layer):
- """Entropy bottleneck layer.
-
- This layer can be used to model the entropy (the amount of information
- conveyed) of the tensor passing through it. During training, this can be used
- to impose a (soft) entropy constraint on its activations, limiting the amount
- of information flowing through the layer. Note that this is distinct from
- other types of bottlenecks, which reduce the dimensionality of the space, for
- example. Dimensionality reduction does not limit the amount of information,
- and does not enable efficient data compression per se.
-
- After training, this layer can be used to compress any input tensor to a
- string, which may be written to a file, and to decompress a file which it
- previously generated back to a reconstructed tensor (possibly on a different
- machine having access to the same model checkpoint). The entropies estimated
- during training or evaluation are approximately equal to the average length of
- the strings in bits.
-
- The layer implements a flexible probability density model to estimate entropy,
- which is described in the appendix of the paper (please cite the paper if you
- use this code for scientific work):
-
- "Variational image compression with a scale hyperprior"
-
- Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston
-
- https://arxiv.org/abs/1802.01436
-
- The layer assumes that the input tensor is at least 2D, with a batch dimension
- at the beginning and a channel dimension as specified by `data_format`. The
- layer trains an independent probability density model for each channel, but
- assumes that across all other dimensions, the inputs are i.i.d. (independent
- and identically distributed). Because the entropy (and hence, average
- codelength) is a function of the densities, this assumption may have a direct
- effect on the compression performance.
-
- Because data compression always involves discretization, the outputs of the
- layer are generally only approximations of its inputs. During training,
- discretization is modeled using additive uniform noise to ensure
- differentiability. The entropies computed during training are differential
- entropies. During evaluation, the data is actually quantized, and the
- entropies are discrete (Shannon entropies). To make sure the approximated
- tensor values are good enough for practical purposes, the training phase must
- be used to balance the quality of the approximation with the entropy, by
- adding an entropy term to the training loss, as in the following example.
-
- Here, we use the entropy bottleneck to compress the latent representation of
- an autoencoder. The data vectors `x` in this case are 4D tensors in
- `'channels_last'` format (for example, 16x16 pixel grayscale images).
-
- The layer always produces exactly one auxiliary loss and one update op which
- are only significant for compression and decompression. To use the compression
- feature, the auxiliary loss must be minimized during or after training. After
- that, the update op must be executed at least once. Here, we simply attach
- them to the main training step.
-
- Training:
- ```
- # Build autoencoder.
- x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1])
- y = forward_transform(x)
- entropy_bottleneck = EntropyBottleneck()
- y_, likelihoods = entropy_bottleneck(y, training=True)
- x_ = backward_transform(y_)
-
- # Information content (= predicted codelength) in bits of each batch element
- # (note that taking the natural logarithm and dividing by `log(2)` is
- # equivalent to taking base-2 logarithms):
- bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2)
-
- # Squared difference of each batch element:
- squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3))
-
- # The loss is a weighted sum of mean squared error and entropy (average
- # information content), where the weight controls the trade-off between
- # approximation error and entropy.
- main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits)
-
- # Minimize loss and auxiliary loss, and execute update op.
- main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
-  main_step = main_optimizer.minimize(main_loss)
- # 1e-2 is a good starting point for the learning rate of the auxiliary loss,
- # assuming Adam is used.
- aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2)
-  aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])
- step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
- ```
-
- Evaluation:
- ```
- # Build autoencoder.
- x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1])
- y = forward_transform(x)
- y_, likelihoods = EntropyBottleneck()(y, training=False)
- x_ = backward_transform(y_)
-
- # Information content (= predicted codelength) in bits of each batch element:
- bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2)
-
- # Squared difference of each batch element:
- squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3))
-
- # The loss is a weighted sum of mean squared error and entropy (average
- # information content), where the weight controls the trade-off between
- # approximation error and entropy.
- loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits)
- ```
-
- To be able to compress the bottleneck tensor and decompress it in a different
- session, or on a different machine, you need three items:
- - The compressed representations stored as strings.
- - The shape of the bottleneck for these string representations as a `Tensor`,
- as well as the number of channels of the bottleneck at graph construction
- time.
- - The checkpoint of the trained model that was used for compression. Note:
- It is crucial that the auxiliary loss produced by this layer is minimized
- during or after training, and that the update op is run after training and
- minimization of the auxiliary loss, but *before* the checkpoint is saved.
-
- Compression:
- ```
- x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1])
- y = forward_transform(x)
- strings = EntropyBottleneck().compress(y)
- shape = tf.shape(y)[1:]
- ```
-
- Decompression:
- ```
- strings = tf.placeholder(tf.string, shape=[None])
- shape = tf.placeholder(tf.int32, shape=[3])
- entropy_bottleneck = EntropyBottleneck(dtype=tf.float32)
- y_ = entropy_bottleneck.decompress(strings, shape, channels=5)
- x_ = backward_transform(y_)
- ```
- Here, we assumed that the tensor produced by the forward transform has 5
- channels.
-
- The above four use cases can also be implemented within the same session (i.e.
- on the same `EntropyBottleneck` instance), for testing purposes, etc., by
- calling the object more than once.
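-
-  For example, a minimal single-instance round trip for testing might look like
-  the following sketch (assuming `x`, `forward_transform` and
-  `backward_transform` as above; remember that at run time the auxiliary loss
-  must be minimized and the update op executed before `compress` or
-  `decompress` is used):
-  ```
-  entropy_bottleneck = EntropyBottleneck()
-  y = forward_transform(x)
-  y_tilde, likelihoods = entropy_bottleneck(y, training=True)
-  strings = entropy_bottleneck.compress(y)
-  y_hat = entropy_bottleneck.decompress(strings, tf.shape(y)[1:])
-  x_hat = backward_transform(y_hat)
-  ```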
-
- Arguments:
- init_scale: Float. A scaling factor determining the initial width of the
- probability densities. This should be chosen big enough so that the
- range of values of the layer inputs roughly falls within the interval
- [`-init_scale`, `init_scale`] at the beginning of training.
- filters: An iterable of ints, giving the number of filters at each layer of
- the density model. Generally, the more filters and layers, the more
-      expressive the density model is at modeling complicated
- distributions of the layer inputs. For details, refer to the paper
- referenced above. The default is `[3, 3, 3]`, which should be sufficient
- for most practical purposes.
- tail_mass: Float, between 0 and 1. The bottleneck layer automatically
- determines the range of input values that should be represented based on
- their frequency of occurrence. Values occurring in the tails of the
- distributions will be clipped to that range during compression.
- `tail_mass` determines the amount of probability mass in the tails which
- is cut off in the worst case. For example, the default value of `1e-9`
- means that at most 1 in a billion input samples will be clipped to the
- range.
- optimize_integer_offset: Boolean. Typically, the input values of this layer
- are floats, which means that quantization during evaluation can be
- performed with an arbitrary offset. By default, the layer determines that
- offset automatically. In special situations, such as when it is known that
- the layer will receive only full integer values during evaluation, it can
- be desirable to set this argument to `False` instead, in order to always
- quantize to full integer values.
- likelihood_bound: Float. If positive, the returned likelihood values are
- ensured to be greater than or equal to this value. This prevents very
- large gradients with a typical entropy loss (defaults to 1e-9).
- range_coder_precision: Integer, between 1 and 16. The precision of the range
- coder used for compression and decompression. This trades off computation
- speed with compression efficiency, where 16 is the slowest but most
- efficient setting. Choosing lower values may increase the average
- codelength slightly compared to the estimated entropies.
- data_format: Either `'channels_first'` or `'channels_last'` (default).
- trainable: Boolean. Whether the layer should be trained.
- name: String. The name of the layer.
- dtype: Default dtype of the layer's parameters (default of `None` means use
- the type of the first input).
-
- Read-only properties:
- init_scale: See above.
- filters: See above.
- tail_mass: See above.
- optimize_integer_offset: See above.
- likelihood_bound: See above.
- range_coder_precision: See above.
- data_format: See above.
- name: String. See above.
- dtype: See above.
- trainable_variables: List of trainable variables.
- non_trainable_variables: List of non-trainable variables.
- variables: List of all variables of this layer, trainable and non-trainable.
- updates: List of update ops of this layer. Always contains exactly one
- update op, which must be run once after the last training step, before
- `compress` or `decompress` is used.
- losses: List of losses added by this layer. Always contains exactly one
- auxiliary loss, which must be added to the training loss.
-
- Mutable properties:
- trainable: Boolean. Whether the layer should be trained.
- input_spec: Optional `InputSpec` object specifying the constraints on inputs
- that can be accepted by the layer.
- """
-
- def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9,
- optimize_integer_offset=True, likelihood_bound=1e-9,
- range_coder_precision=16, data_format="channels_last", **kwargs):
- super(EntropyBottleneck, self).__init__(**kwargs)
- self._init_scale = float(init_scale)
- self._filters = tuple(int(f) for f in filters)
- self._tail_mass = float(tail_mass)
- if not 0 < self.tail_mass < 1:
- raise ValueError(
- "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass))
- self._optimize_integer_offset = bool(optimize_integer_offset)
- self._likelihood_bound = float(likelihood_bound)
- self._range_coder_precision = int(range_coder_precision)
- self._data_format = data_format
- self._channel_axis(2) # trigger ValueError early
- self.input_spec = base_layer.InputSpec(min_ndim=2)
-
- @property
- def init_scale(self):
- return self._init_scale
-
- @property
- def filters(self):
- return self._filters
-
- @property
- def tail_mass(self):
- return self._tail_mass
-
- @property
- def optimize_integer_offset(self):
- return self._optimize_integer_offset
-
- @property
- def likelihood_bound(self):
- return self._likelihood_bound
-
- @property
- def range_coder_precision(self):
- return self._range_coder_precision
-
- @property
- def data_format(self):
- return self._data_format
-
- def _channel_axis(self, ndim):
- try:
- return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format]
- except KeyError:
- raise ValueError("Unsupported `data_format` for {} layer: {}.".format(
- self.__class__.__name__, self.data_format))
-
- def _logits_cumulative(self, inputs, stop_gradient):
- """Evaluate logits of the cumulative densities.
-
- Args:
- inputs: The values at which to evaluate the cumulative densities, expected
- to be a `Tensor` of shape `(channels, 1, batch)`.
- stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so
- that the gradient of the output with respect to the density model
- parameters is disconnected (the gradient with respect to `inputs` is
- left untouched).
-
- Returns:
- A `Tensor` of the same shape as `inputs`, containing the logits of the
- cumulative densities evaluated at the given inputs.
- """
- logits = inputs
-
- for i in range(len(self.filters) + 1):
- matrix = self._matrices[i]
- if stop_gradient:
- matrix = array_ops.stop_gradient(matrix)
- logits = math_ops.matmul(matrix, logits)
-
- bias = self._biases[i]
- if stop_gradient:
- bias = array_ops.stop_gradient(bias)
- logits += bias
-
- if i < len(self._factors):
- factor = self._factors[i]
- if stop_gradient:
- factor = array_ops.stop_gradient(factor)
- logits += factor * math_ops.tanh(logits)
-
- return logits
-
- def build(self, input_shape):
- """Builds the layer.
-
- Creates the variables for the network modeling the densities, creates the
- auxiliary loss estimating the median and tail quantiles of the densities,
- and then uses that to create the probability mass functions and the update
- op that produces the discrete cumulative density functions used by the range
- coder.
-
- Args:
- input_shape: Shape of the input tensor, used to get the number of
- channels.
-
- Raises:
- ValueError: if `input_shape` doesn't specify the length of the channel
- dimension.
- """
- input_shape = tensor_shape.TensorShape(input_shape)
- channel_axis = self._channel_axis(input_shape.ndims)
- channels = input_shape[channel_axis].value
- if channels is None:
- raise ValueError("The channel dimension of the inputs must be defined.")
- self.input_spec = base_layer.InputSpec(
- ndim=input_shape.ndims, axes={channel_axis: channels})
- filters = (1,) + self.filters + (1,)
- scale = self.init_scale ** (1 / (len(self.filters) + 1))
-
- # Create variables.
- self._matrices = []
- self._biases = []
- self._factors = []
- for i in range(len(self.filters) + 1):
- init = np.log(np.expm1(1 / scale / filters[i + 1]))
- matrix = self.add_variable(
- "matrix_{}".format(i), dtype=self.dtype,
- shape=(channels, filters[i + 1], filters[i]),
- initializer=init_ops.Constant(init))
- matrix = nn.softplus(matrix)
- self._matrices.append(matrix)
-
- bias = self.add_variable(
- "bias_{}".format(i), dtype=self.dtype,
- shape=(channels, filters[i + 1], 1),
- initializer=init_ops.RandomUniform(-.5, .5))
- self._biases.append(bias)
-
- if i < len(self.filters):
- factor = self.add_variable(
- "factor_{}".format(i), dtype=self.dtype,
- shape=(channels, filters[i + 1], 1),
- initializer=init_ops.Zeros())
- factor = math_ops.tanh(factor)
- self._factors.append(factor)
-
- # To figure out what range of the densities to sample, we need to compute
- # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we
- # can't take inverses of the cumulative directly, we make it an optimization
- # problem:
- # `quantiles = argmin(|logit(cumulative) - target|)`
- # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`.
- # Taking the logit (inverse of sigmoid) of the cumulative makes the
- # representation of the right target more numerically stable.
-
- # Numerically stable way of computing logits of `tail_mass / 2`
- # and `1 - tail_mass / 2`.
- target = np.log(2 / self.tail_mass - 1)
- # Compute lower and upper tail quantile as well as median.
- target = constant_op.constant([-target, 0, target], dtype=self.dtype)
-
- def quantiles_initializer(shape, dtype=None, partition_info=None):
- del partition_info # unused
- assert tuple(shape[1:]) == (1, 3)
- init = constant_op.constant(
- [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype)
- return array_ops.tile(init, (shape[0], 1, 1))
-
- quantiles = self.add_variable(
- "quantiles", shape=(channels, 1, 3), dtype=self.dtype,
- initializer=quantiles_initializer)
- logits = self._logits_cumulative(quantiles, stop_gradient=True)
- loss = math_ops.reduce_sum(abs(logits - target))
- self.add_loss(loss, inputs=None)
-
- # Save medians for `call`, `compress`, and `decompress`.
- self._medians = quantiles[:, :, 1:2]
- if not self.optimize_integer_offset:
- self._medians = math_ops.round(self._medians)
-
- # Largest distance observed between lower tail quantile and median,
- # or between median and upper tail quantile.
- minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1])
- maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians)
- minmax = math_ops.maximum(minima, maxima)
- minmax = math_ops.ceil(minmax)
- minmax = math_ops.maximum(minmax, 1)
-
- # Sample the density up to `minmax` around the median.
- samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype)
- samples += self._medians
-
- half = constant_op.constant(.5, dtype=self.dtype)
- # We strip the sigmoid from the end here, so we can use the special rule
- # below to only compute differences in the left tail of the sigmoid.
- # This increases numerical stability (see explanation in `call`).
- lower = self._logits_cumulative(samples - half, stop_gradient=True)
- upper = self._logits_cumulative(samples + half, stop_gradient=True)
- # Flip signs if we can move more towards the left tail of the sigmoid.
- sign = -math_ops.sign(math_ops.add_n([lower, upper]))
- pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower))
- # Add tail masses to first and last bin of pmf, as we clip values for
- # compression, meaning that out-of-range values get mapped to these bins.
- pmf = array_ops.concat([
- math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]),
- pmf[:, 0, 1:-1],
- math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]),
- ], axis=-1)
- self._pmf = pmf
-
- cdf = coder_ops.pmf_to_quantized_cdf(
- pmf, precision=self.range_coder_precision)
- def cdf_getter(*args, **kwargs):
- del args, kwargs # ignored
- return variable_scope.get_variable(
- "quantized_cdf", dtype=dtypes.int32, initializer=cdf,
- trainable=False, validate_shape=False, collections=())
- # Need to provide a fake shape here since add_variable insists on it.
- self._quantized_cdf = self.add_variable(
- "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32,
- getter=cdf_getter, trainable=False)
-
- update_op = state_ops.assign(
- self._quantized_cdf, cdf, validate_shape=False)
- self.add_update(update_op, inputs=None)
-
- super(EntropyBottleneck, self).build(input_shape)
-
- def call(self, inputs, training):
- """Pass a tensor through the bottleneck.
-
- Args:
- inputs: The tensor to be passed through the bottleneck.
- training: Boolean. If `True`, returns a differentiable approximation of
- the inputs, and their likelihoods under the modeled probability
- densities. If `False`, returns the quantized inputs and their
- likelihoods under the corresponding probability mass function. These
- quantities can't be used for training, as they are not differentiable,
- but represent actual compression more closely.
-
- Returns:
- values: `Tensor` with the same shape as `inputs` containing the perturbed
- or quantized input values.
- likelihood: `Tensor` with the same shape as `inputs` containing the
- likelihood of `values` under the modeled probability distributions.
-
- Raises:
- ValueError: if `inputs` has different `dtype` or number of channels than
- a previous set of inputs the model was invoked with earlier.
- """
- inputs = ops.convert_to_tensor(inputs)
- ndim = self.input_spec.ndim
- channel_axis = self._channel_axis(ndim)
- half = constant_op.constant(.5, dtype=self.dtype)
-
- # Convert to (channels, 1, batch) format by commuting channels to front
- # and then collapsing.
- order = list(range(ndim))
- order.pop(channel_axis)
- order.insert(0, channel_axis)
- values = array_ops.transpose(inputs, order)
- shape = array_ops.shape(values)
- values = array_ops.reshape(values, (shape[0], 1, -1))
-
- # Add noise or quantize.
- if training:
- noise = random_ops.random_uniform(array_ops.shape(values), -half, half)
- values = math_ops.add_n([values, noise])
- elif self.optimize_integer_offset:
- values = math_ops.round(values - self._medians) + self._medians
- else:
- values = math_ops.round(values)
-
- # Evaluate densities.
- # We can use the special rule below to only compute differences in the left
- # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1
- # for large x, 0 for small x. Subtracting two numbers close to 0 can be done
- # with much higher precision than subtracting two numbers close to 1.
- lower = self._logits_cumulative(values - half, stop_gradient=False)
- upper = self._logits_cumulative(values + half, stop_gradient=False)
- # Flip signs if we can move more towards the left tail of the sigmoid.
- sign = -math_ops.sign(math_ops.add_n([lower, upper]))
- sign = array_ops.stop_gradient(sign)
- likelihood = abs(
- math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower))
- if self.likelihood_bound > 0:
- likelihood_bound = constant_op.constant(
- self.likelihood_bound, dtype=self.dtype)
- # TODO(jballe): Override gradients.
- likelihood = math_ops.maximum(likelihood, likelihood_bound)
-
- # Convert back to input tensor shape.
- order = list(range(1, ndim))
- order.insert(channel_axis, 0)
- values = array_ops.reshape(values, shape)
- values = array_ops.transpose(values, order)
- likelihood = array_ops.reshape(likelihood, shape)
- likelihood = array_ops.transpose(likelihood, order)
-
- if not context.executing_eagerly():
- values_shape, likelihood_shape = self.compute_output_shape(inputs.shape)
- values.set_shape(values_shape)
- likelihood.set_shape(likelihood_shape)
-
- return values, likelihood
-
- def compress(self, inputs):
- """Compress inputs and store their binary representations into strings.
-
- Args:
- inputs: `Tensor` with values to be compressed.
-
- Returns:
- String `Tensor` vector containing the compressed representation of each
- batch element of `inputs`.
- """
- with ops.name_scope(self._name_scope()):
- inputs = ops.convert_to_tensor(inputs)
- if not self.built:
- # Check input assumptions set before layer building, e.g. input rank.
- self._assert_input_compatibility(inputs)
- if self.dtype is None:
- self._dtype = inputs.dtype.base_dtype.name
- self.build(inputs.shape)
-
- # Check input assumptions set after layer building, e.g. input shape.
- if not context.executing_eagerly():
- self._assert_input_compatibility(inputs)
-
- ndim = self.input_spec.ndim
- channel_axis = self._channel_axis(ndim)
- # Tuple of slices for expanding dimensions of tensors below.
- slices = ndim * [None] + [slice(None)]
- slices[channel_axis] = slice(None)
- slices = tuple(slices)
-
- # Expand dimensions of CDF to input dimensions, keeping the channels along
- # the right dimension.
- cdf = self._quantized_cdf[slices[1:]]
- num_levels = array_ops.shape(cdf)[-1] - 1
-
- # Bring inputs to the right range by centering the range on the medians.
- half = constant_op.constant(.5, dtype=self.dtype)
- medians = array_ops.squeeze(self._medians, [1, 2])
- offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians
- # Expand offsets to input dimensions and add to inputs.
- values = inputs + offsets[slices[:-1]]
-
- # Clip to range and cast to integers. Because we have added .5 above, and
- # all values are positive, the cast effectively implements rounding.
- values = math_ops.maximum(values, half)
- values = math_ops.minimum(
- values, math_ops.cast(num_levels, self.dtype) - half)
- values = math_ops.cast(values, dtypes.int16)
-
- def loop_body(tensor):
- return coder_ops.range_encode(
- tensor, cdf, precision=self.range_coder_precision)
- strings = functional_ops.map_fn(
- loop_body, values, dtype=dtypes.string, back_prop=False)
-
- if not context.executing_eagerly():
- strings.set_shape(inputs.shape[:1])
-
- return strings
-
- def decompress(self, strings, shape, channels=None):
- """Decompress values from their compressed string representations.
-
- Args:
- strings: A string `Tensor` vector containing the compressed data.
- shape: A `Tensor` vector of int32 type. Contains the shape of the tensor
- to be decompressed, excluding the batch dimension.
-      channels: Integer. Specifies the number of channels statically. Only needs to
- be set if the layer hasn't been built yet (i.e., this is the first input
- it receives).
-
- Returns:
- The decompressed `Tensor`. Its shape will be equal to `shape` prepended
- with the batch dimension from `strings`.
-
- Raises:
- ValueError: If the length of `shape` isn't available at graph construction
- time.
- """
- with ops.name_scope(self._name_scope()):
- strings = ops.convert_to_tensor(strings)
- shape = ops.convert_to_tensor(shape)
- if self.built:
- ndim = self.input_spec.ndim
- channel_axis = self._channel_axis(ndim)
- if channels is None:
- channels = self.input_spec.axes[channel_axis]
- else:
- if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1):
- raise ValueError("`shape` must be a vector with known length.")
- ndim = shape.shape[0].value + 1
- channel_axis = self._channel_axis(ndim)
- input_shape = ndim * [None]
- input_shape[channel_axis] = channels
- self.build(input_shape)
-
- # Tuple of slices for expanding dimensions of tensors below.
- slices = ndim * [None] + [slice(None)]
- slices[channel_axis] = slice(None)
- slices = tuple(slices)
-
- # Expand dimensions of CDF to input dimensions, keeping the channels along
- # the right dimension.
- cdf = self._quantized_cdf[slices[1:]]
- num_levels = array_ops.shape(cdf)[-1] - 1
-
- def loop_body(string):
- return coder_ops.range_decode(
- string, shape, cdf, precision=self.range_coder_precision)
- outputs = functional_ops.map_fn(
- loop_body, strings, dtype=dtypes.int16, back_prop=False)
- outputs = math_ops.cast(outputs, self.dtype)
-
- medians = array_ops.squeeze(self._medians, [1, 2])
- offsets = math_ops.cast(num_levels // 2, self.dtype) - medians
- outputs -= offsets[slices[:-1]]
-
- if not context.executing_eagerly():
- outputs_shape = ndim * [None]
- outputs_shape[0] = strings.shape[0]
- outputs_shape[channel_axis] = channels
- outputs.set_shape(outputs_shape)
-
- return outputs
-
- def visualize(self):
- """Multi-channel visualization of densities as images.
-
-    Creates and returns an image summary visualizing the current probability
- density estimates. The image contains one row for each channel. Within each
- row, the pixel intensities are proportional to probability values, and each
- row is centered on the median of the corresponding distribution.
-
- Returns:
- The created image summary.
- """
- with ops.name_scope(self._name_scope()):
- image = self._pmf
- image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True)
- image = math_ops.cast(image + .5, dtypes.uint8)
- image = image[None, :, :, None]
- return summary.image("pmf", image, max_outputs=1)
-
- def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape)
- return input_shape, input_shape
diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py
deleted file mode 100644
index 798b0234eb..0000000000
--- a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py
+++ /dev/null
@@ -1,315 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests of EntropyBottleneck class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.coder.python.layers import entropybottleneck
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.training import gradient_descent
-
-
-class EntropyBottleneckTest(test.TestCase):
-
- def test_noise(self):
- # Tests that the noise added is uniform noise between -0.5 and 0.5.
- inputs = array_ops.placeholder(dtypes.float32, (None, 1))
- layer = entropybottleneck.EntropyBottleneck()
- noisy, _ = layer(inputs, training=True)
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- values = np.linspace(-50, 50, 100)[:, None]
- noisy, = sess.run([noisy], {inputs: values})
- self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49))
- self.assertAllClose(values, noisy, rtol=0, atol=.5)
-
- def test_quantization(self):
- # Tests that inputs are quantized to full integer values, even after
- # quantiles have been updated.
- inputs = array_ops.placeholder(dtypes.float32, (None, 1))
- layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False)
- quantized, _ = layer(inputs, training=False)
- opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
- self.assertTrue(len(layer.losses) == 1)
- step = opt.minimize(layer.losses[0])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(step)
- values = np.linspace(-50, 50, 100)[:, None]
- quantized, = sess.run([quantized], {inputs: values})
- self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6)
-
- def test_quantization_optimized_offset(self):
- # Tests that inputs are not quantized to full integer values after quantiles
- # have been updated. However, the difference between input and output should
- # be between -0.5 and 0.5, and the offset must be consistent.
- inputs = array_ops.placeholder(dtypes.float32, (None, 1))
- layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True)
- quantized, _ = layer(inputs, training=False)
- opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
- self.assertTrue(len(layer.losses) == 1)
- step = opt.minimize(layer.losses[0])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(step)
- values = np.linspace(-50, 50, 100)[:, None]
- quantized, = sess.run([quantized], {inputs: values})
- self.assertAllClose(values, quantized, rtol=0, atol=.5)
- diff = np.ravel(np.around(values) - quantized) % 1
- self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6)
- self.assertNotEqual(diff[0], 0)
-
- def test_codec(self):
- # Tests that inputs are compressed and decompressed correctly, and quantized
- # to full integer values, even after quantiles have been updated.
- inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
- layer = entropybottleneck.EntropyBottleneck(
- data_format="channels_last", init_scale=60,
- optimize_integer_offset=False)
- bitstrings = layer.compress(inputs)
- decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
- opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
- self.assertTrue(len(layer.losses) == 1)
- step = opt.minimize(layer.losses[0])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(step)
- self.assertTrue(len(layer.updates) == 1)
- sess.run(layer.updates[0])
- values = np.linspace(-50, 50, 100)[None, :, None]
- decoded, = sess.run([decoded], {inputs: values})
- self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6)
-
- def test_codec_optimized_offset(self):
- # Tests that inputs are compressed and decompressed correctly, and not
- # quantized to full integer values after quantiles have been updated.
- # However, the difference between input and output should be between -0.5
- # and 0.5, and the offset must be consistent.
- inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
- layer = entropybottleneck.EntropyBottleneck(
- data_format="channels_last", init_scale=60,
- optimize_integer_offset=True)
- bitstrings = layer.compress(inputs)
- decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
- opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
- self.assertTrue(len(layer.losses) == 1)
- step = opt.minimize(layer.losses[0])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(step)
- self.assertTrue(len(layer.updates) == 1)
- sess.run(layer.updates[0])
- values = np.linspace(-50, 50, 100)[None, :, None]
- decoded, = sess.run([decoded], {inputs: values})
- self.assertAllClose(values, decoded, rtol=0, atol=.5)
- diff = np.ravel(np.around(values) - decoded) % 1
- self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6)
- self.assertNotEqual(diff[0], 0)
-
- def test_codec_clipping(self):
- # Tests that inputs are compressed and decompressed correctly, and clipped
- # to the expected range.
- inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
- layer = entropybottleneck.EntropyBottleneck(
- data_format="channels_last", init_scale=40)
- bitstrings = layer.compress(inputs)
- decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- self.assertTrue(len(layer.updates) == 1)
- sess.run(layer.updates[0])
- values = np.linspace(-50, 50, 100)[None, :, None]
- decoded, = sess.run([decoded], {inputs: values})
- expected = np.clip(np.around(values), -40, 40)
- self.assertAllClose(expected, decoded, rtol=0, atol=1e-6)
-
- def test_channels_last(self):
- # Test the layer with more than one channel and multiple input dimensions,
- # with the channels in the last dimension.
- inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2))
- layer = entropybottleneck.EntropyBottleneck(
- data_format="channels_last", init_scale=50)
- noisy, _ = layer(inputs, training=True)
- quantized, _ = layer(inputs, training=False)
- bitstrings = layer.compress(inputs)
- decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- self.assertTrue(len(layer.updates) == 1)
- sess.run(layer.updates[0])
- values = 5 * np.random.normal(size=(7, 5, 3, 2))
- noisy, quantized, decoded = sess.run(
- [noisy, quantized, decoded], {inputs: values})
- self.assertAllClose(values, noisy, rtol=0, atol=.5)
- self.assertAllClose(values, quantized, rtol=0, atol=.5)
- self.assertAllClose(values, decoded, rtol=0, atol=.5)
-
- def test_channels_first(self):
- # Test the layer with more than one channel and multiple input dimensions,
- # with the channel dimension right after the batch dimension.
- inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None))
- layer = entropybottleneck.EntropyBottleneck(
- data_format="channels_first", init_scale=50)
- noisy, _ = layer(inputs, training=True)
- quantized, _ = layer(inputs, training=False)
- bitstrings = layer.compress(inputs)
- decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- self.assertTrue(len(layer.updates) == 1)
- sess.run(layer.updates[0])
- values = 5 * np.random.normal(size=(2, 3, 5, 7))
- noisy, quantized, decoded = sess.run(
- [noisy, quantized, decoded], {inputs: values})
- self.assertAllClose(values, noisy, rtol=0, atol=.5)
- self.assertAllClose(values, quantized, rtol=0, atol=.5)
- self.assertAllClose(values, decoded, rtol=0, atol=.5)
-
- def test_compress(self):
- # Test compression and decompression, and produce test data for
- # `test_decompress`. If you set the constant at the end to `True`, this test
- # will fail and the log will contain the new test data.
- inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10))
- layer = entropybottleneck.EntropyBottleneck(
- data_format="channels_first", filters=(), init_scale=2)
- bitstrings = layer.compress(inputs)
- decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- self.assertTrue(len(layer.updates) == 1)
- sess.run(layer.updates[0])
- values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5
- bitstrings, quantized_cdf, decoded = sess.run(
- [bitstrings, layer._quantized_cdf, decoded], {inputs: values})
- self.assertAllClose(values, decoded, rtol=0, atol=.5)
- # Set this constant to `True` to log new test data for `test_decompress`.
- if False: # pylint:disable=using-constant-test
- assert False, (bitstrings, quantized_cdf, decoded)
-
- # Data generated by `test_compress`.
- # pylint:disable=g-inconsistent-quotes,bad-whitespace
- bitstrings = np.array([
- b'\x1e\xbag}\xc2\xdaN\x8b\xbd.',
- b'\x8dF\xf0%\x1cv\xccllW'
- ], dtype=object)
-
- quantized_cdf = np.array([
- [ 0, 15636, 22324, 30145, 38278, 65536],
- [ 0, 19482, 26927, 35052, 42904, 65535],
- [ 0, 21093, 28769, 36919, 44578, 65536]
- ], dtype=np.int32)
-
- expected = np.array([
- [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.],
- [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.],
- [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]],
- [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.],
- [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.],
- [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]]
- ], dtype=np.float32)
- # pylint:enable=g-inconsistent-quotes,bad-whitespace
-
- def test_decompress(self):
- # Test that decompression of values compressed with a previous version
- # works, i.e. that the file format doesn't change across revisions.
- bitstrings = array_ops.placeholder(dtypes.string)
- input_shape = array_ops.placeholder(dtypes.int32)
- quantized_cdf = array_ops.placeholder(dtypes.int32)
- layer = entropybottleneck.EntropyBottleneck(
- data_format="channels_first", filters=(), dtype=dtypes.float32)
- layer.build(self.expected.shape)
- layer._quantized_cdf = quantized_cdf
- decoded = layer.decompress(bitstrings, input_shape[1:])
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- decoded, = sess.run([decoded], {
- bitstrings: self.bitstrings, input_shape: self.expected.shape,
- quantized_cdf: self.quantized_cdf})
- self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6)
-
- def test_build_decompress(self):
- # Test that layer can be built when `decompress` is the first call to it.
- bitstrings = array_ops.placeholder(dtypes.string)
- input_shape = array_ops.placeholder(dtypes.int32, shape=[3])
- layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
- layer.decompress(bitstrings, input_shape[1:], channels=5)
- self.assertTrue(layer.built)
-
- def test_pmf_normalization(self):
- # Test that probability mass functions are normalized correctly.
- layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
- layer.build((None, 10))
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- pmf, = sess.run([layer._pmf])
- self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6)
-
- def test_visualize(self):
- # Test that summary op can be constructed.
- layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
- layer.build((None, 10))
- summary = layer.visualize()
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- sess.run([summary])
-
- def test_normalization(self):
- # Test that densities are normalized correctly.
- inputs = array_ops.placeholder(dtypes.float32, (None, 1))
- layer = entropybottleneck.EntropyBottleneck(filters=(2,))
- _, likelihood = layer(inputs, training=True)
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- x = np.repeat(np.arange(-200, 201), 1000)[:, None]
- likelihood, = sess.run([likelihood], {inputs: x})
- self.assertEqual(x.shape, likelihood.shape)
- integral = np.sum(likelihood) * .001
- self.assertAllClose(1, integral, rtol=0, atol=1e-4)
-
- def test_entropy_estimates(self):
- # Test that entropy estimates match actual range coding.
- inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
- layer = entropybottleneck.EntropyBottleneck(
- filters=(2, 3), data_format="channels_last")
- _, likelihood = layer(inputs, training=True)
- diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2)
- _, likelihood = layer(inputs, training=False)
- disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2)
- bitstrings = layer.compress(inputs)
- with self.test_session() as sess:
- sess.run(variables.global_variables_initializer())
- self.assertTrue(len(layer.updates) == 1)
- sess.run(layer.updates[0])
- diff_entropy, disc_entropy, bitstrings = sess.run(
- [diff_entropy, disc_entropy, bitstrings],
- {inputs: np.random.normal(size=(1, 10000, 1))})
- codelength = 8 * sum(len(bitstring) for bitstring in bitstrings)
- self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0)
- self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0)
- self.assertGreater(codelength, disc_entropy)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 2de1a79d28..00f5b74c33 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -210,7 +210,9 @@ py_test(
srcs = ["optimize_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":stats_dataset_test_base",
"//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
@@ -239,7 +241,7 @@ cuda_py_test(
tags = [
"manual",
"no_oss",
- "no_windows_gpu" +
+ "no_windows_gpu",
"notap",
],
)
@@ -431,8 +433,8 @@ py_test(
tags = ["no_pip"],
deps = [
":reader_dataset_ops_test_base",
+ ":stats_dataset_test_base",
"//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -442,6 +444,16 @@ py_test(
],
)
+py_library(
+ name = "stats_dataset_test_base",
+ srcs = ["stats_dataset_test_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "threadpool_dataset_ops_test",
size = "small",
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 30a993b1f7..77148aceec 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase):
def _read_vars(self, model_dir):
"""Returns (global_step, latest_feature)."""
with ops.Graph().as_default() as g:
- ckpt_path = saver_lib.latest_checkpoint(model_dir)
+ ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
meta_filename = ckpt_path + '.meta'
saver_lib.import_meta_graph(meta_filename)
saver = saver_lib.Saver()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index d8156dc9c7..2427935c73 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -19,7 +19,9 @@ from __future__ import print_function
from absl.testing import parameterized
+from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -160,5 +162,34 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(get_next)
+class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testLatencyStatsOptimization(self):
+
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.from_tensors(1).apply(
+ optimization.assert_next(
+ ["LatencyStats", "Map", "LatencyStats", "Prefetch",
+ "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
+ optimization.optimize(["latency_all_edges"])).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertEqual(1 * 1, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_TensorDataset/_1", 1)
+ self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
+ 1)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_PrefetchDataset/_6", 1)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 2da6131e8e..d66305d732 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -907,6 +907,42 @@ class CopyToDeviceTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testIteratorGetNextAsOptionalOnGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(3)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_elem = iterator_ops.get_next_as_optional(iterator)
+ elem_has_value_t = next_elem.has_value()
+ elem_value_t = next_elem.get_value()
+
+ with self.test_session() as sess:
+ # Before initializing the iterator, evaluating the optional fails with
+ # a FailedPreconditionError.
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_has_value_t)
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_value_t)
+
+ # For each element of the dataset, assert that the optional evaluates to
+ # the expected value.
+ sess.run(iterator.initializer)
+ for i in range(3):
+ elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
+ self.assertTrue(elem_has_value)
+ self.assertEqual(i, elem_value)
+
+ # After exhausting the iterator, `next_elem.has_value()` will evaluate to
+ # false, and attempting to get the value will fail.
+ for _ in range(2):
+ self.assertFalse(sess.run(elem_has_value_t))
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(elem_value_t)
+
class MultiDeviceIteratorTest(test.TestCase):
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 851a33dfc8..15b342d30f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -173,15 +173,23 @@ class ReadBatchFeaturesTest(
for num_epochs in [1, 10]:
with ops.Graph().as_default():
# Basic test: read from file 0.
- self.outputs = self.make_batch_feature(
+ outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
drop_final_batch=True).make_one_shot_iterator().get_next()
- for _, tensor in self.outputs.items():
+ for _, tensor in outputs.items():
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0], num_epochs=None, batch_size=32)
+ for shape, clazz in zip(nest.flatten(dataset.output_shapes),
+ nest.flatten(dataset.output_classes)):
+ if issubclass(clazz, ops.Tensor):
+ self.assertEqual(32, shape[0])
+
class MakeCsvDatasetTest(test.TestCase):
@@ -795,6 +803,16 @@ class MakeCsvDatasetTest(test.TestCase):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)
+ def testIndefiniteRepeatShapeInference(self):
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
+ dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
class MakeTFRecordDatasetTest(
reader_dataset_ops_test_base.TFRecordDatasetTestBase):
@@ -1002,5 +1020,12 @@ class MakeTFRecordDatasetTest(
self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
seed=21345)
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 3c3f23f9a9..7b9ea191a4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -56,6 +56,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
index a0a1100893..1b6059ccbc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
@@ -19,6 +19,8 @@ from __future__ import print_function
import os
+from absl.testing import parameterized
+
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
@@ -26,7 +28,8 @@ from tensorflow.python.platform import test
class CacheDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
def setUp(self):
self.range_size = 10
@@ -34,88 +37,123 @@ class CacheDatasetSerializationTest(
self.num_outputs = self.range_size * self.num_repeats
self.cache_file_prefix = 'test'
- def ds_fn(self):
- return dataset_ops.Dataset.range(self.range_size).cache(
- os.path.join(self.get_temp_dir(),
- self.cache_file_prefix)).repeat(self.num_repeats)
+ def make_dataset_fn(self, is_memory):
+ if is_memory:
+ filename = ''
+ else:
+ filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)
+
+ def ds_fn():
+ return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
+ self.num_repeats)
+
+ return ds_fn
def expected_outputs(self):
return list(range(self.range_size)) * self.num_repeats
- def testCheckpointBeforeOneEpoch(self):
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
# Generate 5 entries from iterator and save checkpoint.
- outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False)
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
self.assertSequenceEqual(outputs, range(5))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
outputs.extend(
self.gen_outputs(
- self.ds_fn, [],
+ ds_fn, [],
self.num_outputs - 5,
ckpt_saved=True,
verify_exhausted=False))
self.assertSequenceEqual(outputs, self.expected_outputs())
- def testCheckpointBeforeOneEpochThenRunFewSteps(self):
- # Generate 8 entries from iterator but save checkpoint after producing
- # 5.
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 8 entries from iterator but save checkpoint after producing 5.
outputs = self.gen_outputs(
- self.ds_fn, [5],
- 8,
- verify_exhausted=False,
- save_checkpoint_at_end=False)
+ ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, range(8))
- # Restoring from checkpoint and running GetNext should return a
-    # `AlreadyExistsError` now because the lockfile already exists.
- with self.assertRaises(errors.AlreadyExistsError):
- self.gen_outputs(
- self.ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False)
+ if is_memory:
+ outputs = outputs[:5]
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+ else:
+ # Restoring from checkpoint and running GetNext should return
+      # `AlreadyExistsError` now because the lockfile already exists.
+ with self.assertRaises(errors.AlreadyExistsError):
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointAfterOneEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
- def testCheckpointAfterOneEpoch(self):
# Generate 15 entries from iterator and save checkpoint.
- outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False)
+ outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
# Restore from checkpoint and produce the rest of the elements from the
# iterator.
outputs.extend(
self.gen_outputs(
- self.ds_fn, [],
+ ds_fn, [],
self.num_outputs - 15,
ckpt_saved=True,
verify_exhausted=False))
self.assertSequenceEqual(outputs, self.expected_outputs())
- def testCheckpointAfterOneEpochThenRunFewSteps(self):
- # Generate 18 entries from iterator but save checkpoint after producing
- # 15.
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 18 entries from iterator but save checkpoint after producing 15.
outputs = self.gen_outputs(
- self.ds_fn, [15],
- 18,
- verify_exhausted=False,
- save_checkpoint_at_end=False)
+ ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))
outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
- self.ds_fn, [],
+ ds_fn, [],
self.num_outputs - 15,
ckpt_saved=True,
verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(10)) * 3)
- def testCheckpointBeforeOneEpochButRunCompleteEpoch(self):
- # Generate 13 entries from iterator but save checkpoint after producing
- # 5.
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 13 entries from iterator but save checkpoint after producing 5.
outputs = self.gen_outputs(
- self.ds_fn, [5],
- 13,
- verify_exhausted=False,
- save_checkpoint_at_end=False)
+ ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))
# Since we ran for more than one epoch, the cache was completely written.
@@ -124,65 +162,90 @@ class CacheDatasetSerializationTest(
# been completely written.
outputs = list(range(5)) + self.gen_outputs(
- self.ds_fn, [],
+ ds_fn, [],
self.num_outputs - 5,
ckpt_saved=True,
verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(10)) * 3)
- def testCheckpointUnusedWriterIterator(self):
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointUnusedWriterIterator(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
# Checkpoint before get_next is called even once.
- outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False)
+ outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
self.assertSequenceEqual(outputs, [])
outputs = self.gen_outputs(
- self.ds_fn, [],
- self.num_outputs,
- ckpt_saved=True,
- verify_exhausted=False)
+ ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(10)) * 3)
- def testCheckpointUnusedMidwayWriterIterator(self):
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
# Produce 5 elements and checkpoint.
- outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False)
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
self.assertSequenceEqual(outputs, range(5))
# Restore from checkpoint, then produce no elements and checkpoint.
outputs.extend(
- self.gen_outputs(
- self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
+ self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
self.assertSequenceEqual(outputs, range(5))
# Restore from checkpoint and produce rest of the elements.
outputs.extend(
self.gen_outputs(
- self.ds_fn, [],
+ ds_fn, [],
self.num_outputs - 5,
ckpt_saved=True,
verify_exhausted=False))
self.assertSequenceEqual(outputs, list(range(10)) * 3)
- def testUnusedCheckpointError(self):
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testUnusedCheckpointError(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
# Produce 5 elements and save ckpt.
- outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False)
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
self.assertSequenceEqual(outputs, range(5))
- # Since the complete cache has not been written, a new iterator which does
- # not restore the checkpoint will throw an error since there is a partial
- # cache shard.
- with self.assertRaises(errors.AlreadyExistsError):
+ if is_memory:
outputs = self.gen_outputs(
- self.ds_fn, [], self.num_outputs, verify_exhausted=False)
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+ else:
+ # Since the complete cache has not been written, a new iterator which does
+ # not restore the checkpoint will throw an error since there is a partial
+ # cache shard.
+ with self.assertRaises(errors.AlreadyExistsError):
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testIgnoreCheckpointIfCacheWritten(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
- def testIgnoreCheckpointIfCacheWritten(self):
# Produce 15 elements and save ckpt. This will write the complete cache.
- outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False)
+ outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
# Build the iterator again but do not restore from ckpt. Since the cache
# has already been written we should be able to use it.
outputs = self.gen_outputs(
- self.ds_fn, [], self.num_outputs, verify_exhausted=False)
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
self.assertSequenceEqual(outputs, list(range(10)) * 3)
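The parameterized cases above hinge on how `tf.data.Dataset.cache()` chooses its backing store: an empty filename selects the in-memory cache, while a real path selects the file-backed cache that writes a lockfile and cache shards to disk (which is why only the file variant can raise `AlreadyExistsError` on a partially written cache). A minimal sketch of that switch, assuming a TF build from this era; the element count, repeat count, and temp path are illustrative only:

import os
import tempfile

import tensorflow as tf


def make_cached_dataset(use_memory_cache):
  # An empty filename selects the in-memory cache; a real path selects the
  # file-backed cache that materializes cache shards and a lockfile on disk.
  if use_memory_cache:
    filename = ''
  else:
    filename = os.path.join(tempfile.mkdtemp(), 'cache_prefix')
  return tf.data.Dataset.range(10).cache(filename).repeat(3)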
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
index 393f08850b..3ed4dfb729 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
@@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase):
return os.path.join(self.get_temp_dir(), "iterator")
def _latest_ckpt(self):
- return saver_lib.latest_checkpoint(self.get_temp_dir())
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
def _save(self, sess, saver):
saver.save(sess, self._ckpt_path())
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index b4945685c1..a41d21f8c1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -29,28 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class StatsDatasetTestBase(test.TestCase):
-
- def _assertSummaryHasCount(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.num)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasSum(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.sum)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
-
-class StatsDatasetTest(StatsDatasetTestBase):
+class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testBytesProduced(self):
stats_aggregator = stats_ops.StatsAggregator()
@@ -197,7 +176,7 @@ class StatsDatasetTest(StatsDatasetTestBase):
class FeatureStatsDatasetTest(
- StatsDatasetTestBase,
+ stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
def testFeaturesStats(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
new file mode 100644
index 0000000000..9a13acf8f0
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -0,0 +1,44 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing the input pipeline statistics gathering ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.python.platform import test
+
+
+class StatsDatasetTestBase(test.TestCase):
+ """Base class for testing statistics gathered in `StatsAggregator`."""
+
+ def _assertSummaryHasCount(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.num)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasSum(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.sum)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 42fc20ec01..4835c4e5bd 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -31,7 +31,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -439,48 +438,6 @@ def unbatch():
return _apply_fn
-def _filter_irregular_batches(batch_size):
- """Transformation that filters out batches that are not of size batch_size."""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- tensor_batch_size = ops.convert_to_tensor(
- batch_size, dtype=dtypes.int64, name="batch_size")
-
- flattened = _RestructuredDataset(
- dataset,
- tuple(nest.flatten(dataset.output_types)),
- output_classes=tuple(nest.flatten(dataset.output_classes)))
-
- def _predicate(*xs):
- """Return `True` if this element is a full batch."""
- # Extract the dynamic batch size from the first component of the flattened
- # batched element.
- first_component = xs[0]
- first_component_batch_size = array_ops.shape(
- first_component, out_type=dtypes.int64)[0]
-
- return math_ops.equal(first_component_batch_size, tensor_batch_size)
-
- filtered = flattened.filter(_predicate)
-
- maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
-
- def _set_first_dimension(shape):
- return shape.merge_with(
- tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
-
- known_shapes = nest.map_structure(_set_first_dimension,
- dataset.output_shapes)
- return _RestructuredDataset(
- filtered,
- dataset.output_types,
- known_shapes,
- output_classes=dataset.output_classes)
-
- return _apply_fn
-
-
@deprecation.deprecated(
None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.")
def batch_and_drop_remainder(batch_size):
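The deprecation above points at the core replacement for the removed contrib transformation; a short migration sketch (the dataset and batch size are placeholders, not part of the change):

import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# Before (contrib):  dataset = dataset.apply(batching.batch_and_drop_remainder(32))
# After (core API):
dataset = dataset.batch(32, drop_remainder=True)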
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 0d71be6601..d2c1d0d362 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
@@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
# Check if there is an existing checkpoint. If so, restore from it.
# pylint: disable=protected-access
- latest_checkpoint_path = saver_lib.latest_checkpoint(
+ latest_checkpoint_path = checkpoint_management.latest_checkpoint(
self._checkpoint_saver_hook._checkpoint_dir,
latest_filename=self._latest_filename)
if latest_checkpoint_path:
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index f018dd02e6..14d69f8d5b 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -286,11 +286,14 @@ def make_tf_record_dataset(
dataset = _maybe_shuffle_and_repeat(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+ # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ drop_final_batch = drop_final_batch or num_epochs is None
+
if parser_fn is None:
- if drop_final_batch:
- dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
- else:
- dataset = dataset.batch(batch_size)
+ dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
else:
# TODO(josh11b): if num_parallel_parser_calls is None, use some function
# of num cores instead of map_and_batch's default behavior of one batch.
@@ -493,8 +496,13 @@ def make_csv_dataset(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
# Apply batch before map for perf, because map has high overhead relative
- # to the size of the computation in each map
- dataset = dataset.batch(batch_size=batch_size)
+ # to the size of the computation in each map.
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(batch_size=batch_size,
+ drop_remainder=num_epochs is None)
dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls)
dataset = dataset.prefetch(prefetch_buffer_size)
@@ -772,10 +780,12 @@ def make_batched_features_dataset(file_pattern,
dataset = dataset.apply(stats_ops.feature_stats("record_stats"))
- if drop_final_batch:
- dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
- else:
- dataset = dataset.batch(batch_size)
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(
+ batch_size, drop_remainder=drop_final_batch or num_epochs is None)
# Parse `Example` tensors to a dictionary of `Feature` tensors.
dataset = dataset.map(
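The NOTE(mrry) comments above rely on `drop_remainder=True` making the batch dimension static, which improves downstream shape inference. A small sketch of that effect, assuming the 1.x-era `tf.data` API used in this change; the range and batch size are placeholders:

import tensorflow as tf

ds = tf.data.Dataset.range(100)
print(ds.batch(32).output_shapes)                       # (?,)  -- dynamic batch dim
print(ds.batch(32, drop_remainder=True).output_shapes)  # (32,) -- static batch dim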
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index 1126f76f58..d3628d480d 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -25,10 +25,13 @@ py_library(
srcs = ["__init__.py"],
visibility = ["//tensorflow:internal"],
deps = [
+ "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy",
"//tensorflow/contrib/distribute/python:cross_tower_ops",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
"//tensorflow/contrib/distribute/python:monitor",
+ "//tensorflow/contrib/distribute/python:multi_worker_strategy",
"//tensorflow/contrib/distribute/python:one_device_strategy",
+ "//tensorflow/contrib/distribute/python:parameter_server_strategy",
"//tensorflow/contrib/distribute/python:step_fn",
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 2e2c3be853..9123ca749b 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -19,10 +19,13 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.contrib.distribute.python.cross_tower_ops import *
from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
+from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy
from tensorflow.contrib.distribute.python.monitor import Monitor
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
+from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.training.distribute import *
@@ -32,11 +35,14 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'AllReduceCrossTowerOps',
+ 'CollectiveAllReduceStrategy',
'CrossTowerOps',
'DistributionStrategy',
'MirroredStrategy',
+ 'MultiWorkerMirroredStrategy',
'Monitor',
'OneDeviceStrategy',
+ 'ParameterServerStrategy',
'ReductionToOneDeviceCrossTowerOps',
'Step',
'StandardInputStep',
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index cbe741de5a..d9e66ddac0 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -134,6 +134,24 @@ py_library(
)
py_library(
+ name = "collective_all_reduce_strategy",
+ srcs = ["collective_all_reduce_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":cross_tower_ops",
+ ":cross_tower_utils",
+ ":mirrored_strategy",
+ ":values",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:collective_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_library(
name = "strategy_test_lib",
testonly = 1,
srcs = ["strategy_test_lib.py"],
@@ -293,11 +311,11 @@ py_library(
],
deps = [
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
"//tensorflow/python:distributed_framework_test_lib",
- "//tensorflow/python:platform",
"//tensorflow/python:session",
- "//tensorflow/python:training",
- "//tensorflow/python/eager:test",
+ "//tensorflow/python/estimator:run_config",
+ "//third_party/py/numpy",
],
)
@@ -318,8 +336,7 @@ py_library(
deps = [
":one_device_strategy",
":values",
- "//tensorflow/contrib/tpu",
- "//tensorflow/contrib/tpu:tpu_py",
+ "//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
@@ -327,6 +344,37 @@ py_library(
],
)
+py_test(
+ name = "collective_all_reduce_strategy_test",
+ srcs = ["collective_all_reduce_strategy_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ ":collective_all_reduce_strategy",
+ ":combinations",
+ ":cross_tower_utils",
+ ":multi_worker_test_base",
+ ":strategy_test_lib",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:run_config",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
py_library(
name = "minimize_loss_test_lib",
testonly = 1,
@@ -497,8 +545,11 @@ py_library(
"//tensorflow/contrib/all_reduce:all_reduce_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:collective_ops",
+ "//tensorflow/python:device",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
],
)
@@ -533,7 +584,9 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"@six_archive//:six",
],
@@ -541,6 +594,7 @@ py_library(
cuda_py_test(
name = "cross_tower_ops_test",
+ size = "large",
srcs = ["cross_tower_ops_test.py"],
additional_deps = [
":combinations",
@@ -555,7 +609,6 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
- shard_count = 15,
tags = [
"multi_and_single_gpu",
"no_pip",
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
new file mode 100644
index 0000000000..9afcaecf78
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -0,0 +1,205 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import os
+
+from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.training import server_lib
+
+
+# TODO(yuefengz): move this function to a common util file.
+def _normalize_cluster_spec(cluster_spec):
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ return server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+        "`cluster_spec` should be a dict, a `tf.train.ClusterSpec`, or a "
+ "`tf.train.ClusterDef` object")
+ return cluster_spec
+
+
+# TODO(yuefengz): shard the dataset.
+# TODO(yuefengz): support in-graph replication.
+# TODO(yuefengz): it only works with a cluster without a chief node, maybe
+# support chief node?
+class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
+ """Distribution strategy that uses collective ops for all-reduce.
+
+  It is similar to MirroredStrategy but uses collective ops for reduction.
+  It currently only works for between-graph replication and its reductions
+  run across all workers.
+ """
+
+ def __init__(self,
+ num_gpus_per_worker=0,
+ cluster_spec=None,
+ task_type="worker",
+ task_id=0):
+ """Initializes the object.
+
+ Args:
+ num_gpus_per_worker: number of local GPUs or GPUs per worker.
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type, such as "worker".
+ task_id: the current task id.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec`.
+ """
+ self._num_gpus_per_worker = num_gpus_per_worker
+ self._initialize(cluster_spec, task_type, task_id)
+
+ def _initialize(self, cluster_spec, task_type, task_id):
+ if task_type not in ["chief", "worker"]:
+ raise ValueError(
+ "Unrecognized task_type: %r, valid task types are: \"chief\", "
+ "\"worker\"." % task_type)
+ if cluster_spec:
+ self._cluster_spec = _normalize_cluster_spec(cluster_spec)
+ worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ num_workers = len(self._cluster_spec.as_dict().get(task_type, []))
+ if "chief" in self._cluster_spec.as_dict():
+ num_workers += 1
+ if not num_workers:
+        raise ValueError("`task_type` should be in `cluster_spec`.")
+
+ # TODO(yuefengz): create a utility to infer chief.
+ if "chief" in self._cluster_spec.as_dict() and task_type == "chief":
+ assert task_id == 0
+ self._is_chief = True
+ else:
+ assert task_type == "worker"
+ self._is_chief = task_id == 0
+ else:
+ self._cluster_spec = None
+ self._is_chief = True
+ worker_device = ""
+ num_workers = 1
+ self._num_workers = num_workers
+
+ if self._num_gpus_per_worker:
+ local_devices = [
+ "%s/device:GPU:%d" % (worker_device, i)
+ for i in range(self._num_gpus_per_worker)
+ ]
+ else:
+ local_devices = [worker_device]
+
+ self._collective_keys = cross_tower_utils.CollectiveKeys()
+ super(CollectiveAllReduceStrategy, self).__init__(
+ devices=local_devices,
+ cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
+ num_workers=num_workers,
+ num_gpus_per_worker=self._num_gpus_per_worker,
+ collective_keys=self._collective_keys))
+
+ # Add a default device so that ops without specified devices will not end up
+ # on other workers.
+ if cluster_spec:
+ self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id)
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ colocate_with = kwargs.pop("colocate_with", None)
+ devices = self._get_devices_from(colocate_with)
+ group_size = len(devices) * self._num_workers
+ group_key = self._collective_keys.get_group_key(self._devices)
+
+ def _real_mirrored_creator(devices, *args, **kwargs):
+ """Creates one MirroredVariable on the current worker."""
+ index = {}
+ collective_instance_key = self._collective_keys.get_instance_key(
+ key_id=kwargs["name"])
+ if "initial_value" not in kwargs:
+ raise ValueError("Initial value must be specified.")
+ initial_value = kwargs["initial_value"]
+ if callable(initial_value):
+ initial_value_fn = initial_value
+ else:
+ initial_value_fn = lambda: initial_value
+
+ for i, d in enumerate(devices):
+ with ops.device(d):
+ if i > 0:
+ # Give replicas meaningful distinct names:
+ var0name = index[devices[0]].name.split(":")[0]
+ # We append a / to variable names created on towers with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+
+          # The initial value fn makes sure that all variables are initialized
+          # to the same values. The first device of the chief worker broadcasts
+          # its variable values to the other devices and workers.
+ def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring
+ with ops.device(device):
+ initial_value = initial_value_fn()
+ assert not callable(initial_value)
+ initial_value = ops.convert_to_tensor(initial_value)
+
+ if self._is_chief and index == 0:
+ bcast_send = collective_ops.broadcast_send(
+ initial_value, initial_value.shape, initial_value.dtype,
+ group_size, group_key, collective_instance_key)
+ with ops.control_dependencies([bcast_send]):
+ return array_ops.identity(initial_value)
+ else:
+ return collective_ops.broadcast_recv(
+ initial_value.shape, initial_value.dtype, group_size,
+ group_key, collective_instance_key)
+
+ kwargs["initial_value"] = _overridden_initial_value_fn
+
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ v = next_creator(*args, **kwargs)
+
+ assert not isinstance(v, values.DistributedVariable)
+ index[d] = v
+ return index
+
+ # pylint: disable=protected-access
+ return mirrored_strategy._create_mirrored_variable(
+ devices, _real_mirrored_creator, *args, **kwargs)
+
+ def configure(self, session_config=None):
+ # Use TF_CONFIG to get the cluster spec and the current job.
+ if not self._cluster_spec:
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+
+ task_env = tf_config.get("task", {})
+ if task_env:
+ task_type = task_env.get("type", "worker")
+ task_id = int(task_env.get("index", "0"))
+ else:
+ task_type = "worker"
+ task_id = 0
+
+ if cluster_spec:
+ self._initialize(cluster_spec, task_type, task_id)
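A hypothetical usage sketch of the strategy defined above, following its constructor arguments; the host names, cluster layout, and GPU count are invented for illustration and are not part of the change:

from tensorflow.contrib.distribute.python import collective_all_reduce_strategy

cluster_spec = {
    "worker": ["host0:2222", "host1:2222", "host2:2222"],
}
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
    num_gpus_per_worker=1,
    cluster_spec=cluster_spec,
    task_type="worker",
    task_id=0)

# Variables created under the scope are mirrored across the local devices and
# broadcast from the chief's first device via collective ops, as implemented
# in _create_variable above.
with strategy.scope():
  pass  # build the replicated model here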
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
new file mode 100644
index 0000000000..b5e54e3b7d
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -0,0 +1,217 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for CollectiveAllReduceStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import run_config
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DistributedCollectiveAllReduceStrategyTest(
+ multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+
+ collective_key_base = 0
+
+ @classmethod
+ def setUpClass(cls):
+    """Create a local cluster with 3 workers."""
+ cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+ cls._cluster_spec = {
+ run_config.TaskType.WORKER: [
+ 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
+ ]
+ }
+
+ def setUp(self):
+ self._run_options = config_pb2.RunOptions()
+ self._run_options.experimental.collective_graph_key = 6
+
+ self._sess_config = config_pb2.ConfigProto()
+ self._sess_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
+
+ # We use a different key_base for each test so that collective keys won't be
+ # reused.
+ # TODO(yuefengz, tucker): enable it to reuse collective keys in different
+ # tests.
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000
+ super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+
+ def _get_test_object(self, task_type, task_id, num_gpus=0):
+ distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=num_gpus,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ collective_keys = cross_tower_utils.CollectiveKeys(
+ group_key_start=10 * num_gpus +
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ instance_key_start=num_gpus * 100 +
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ instance_key_with_id_start=num_gpus * 10000 +
+ DistributedCollectiveAllReduceStrategyTest.collective_key_base)
+ distribution._collective_keys = collective_keys
+ distribution._cross_tower_ops._collective_keys = collective_keys
+ return distribution, self._workers[task_id].target
+
+ def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
+ d, master_target = self._get_test_object(task_type, task_id, num_gpus)
+ with ops.Graph().as_default(), \
+ self.test_session(config=self._sess_config,
+ target=master_target) as sess, \
+ d.scope():
+ l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker)
+
+ def loss_fn(x):
+ y = array_ops.reshape(l(x), []) - constant_op.constant(1.)
+ return y * y
+
+ # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
+ # multiple graphs (b/111216820).
+ def grad_fn(x):
+ loss = loss_fn(x)
+ var_list = (
+ variables.trainable_variables() + ops.get_collection(
+ ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
+ grads = gradients.gradients(loss, var_list)
+ ret = list(zip(grads, var_list))
+ return ret
+
+ def update(v, g):
+ return v.assign_sub(0.05 * g, use_locking=True)
+
+ one = d.broadcast(constant_op.constant([[1.]]))
+
+ def step():
+ """Perform one optimization step."""
+ # Run forward & backward to get gradients, variables list.
+ g_v = d.call_for_each_tower(grad_fn, one)
+ # Update the variables using the gradients and the update() function.
+ before_list = []
+ after_list = []
+ for g, v in g_v:
+ fetched = d.read_var(v)
+ before_list.append(fetched)
+ with ops.control_dependencies([fetched]):
+ # TODO(yuefengz): support non-Mirrored variable as destinations.
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
+ with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ after_list.append(d.read_var(v))
+ return before_list, after_list
+
+ before_out, after_out = step()
+
+ if context.num_gpus() < d._num_gpus_per_worker:
+ return True
+
+ sess.run(
+ variables.global_variables_initializer(), options=self._run_options)
+
+ for i in range(10):
+ b, a = sess.run((before_out, after_out), options=self._run_options)
+ if i == 0:
+ before, = b
+ after, = a
+
+ error_before = abs(before - 1)
+ error_after = abs(after - 1)
+ # Error should go down
+ self.assertLess(error_after, error_before)
+ return error_after < error_before
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ def _test_variable_initialization(self, task_type, task_id, num_gpus):
+ distribution, master_target = self._get_test_object(task_type, task_id,
+ num_gpus)
+ with ops.Graph().as_default(), \
+ self.test_session(config=self._sess_config,
+ target=master_target) as sess, \
+ distribution.scope():
+
+ def model_fn():
+ x = variable_scope.get_variable(
+ 'x',
+ shape=(2, 3),
+ initializer=init_ops.random_uniform_initializer(
+ 1.0, 10.0, dtype=dtypes.float32))
+ return array_ops.identity(x)
+
+ x = distribution.call_for_each_tower(model_fn)
+ reduced_x = distribution.unwrap(
+ distribution.reduce(
+ variable_scope.VariableAggregation.MEAN, x,
+ destinations='/cpu:0'))[0]
+
+ sess.run(
+ variables.global_variables_initializer(), options=self._run_options)
+ x_value, reduced_x_value = sess.run(
+ [x, reduced_x], options=self._run_options)
+ self.assertTrue(np.array_equal(x_value, reduced_x_value))
+ return np.array_equal(x_value, reduced_x_value)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testVariableInitialization(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_variable_initialization,
+ self._cluster_spec,
+ num_gpus=num_gpus)
+
+
+class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase,
+ parameterized.TestCase):
+
+ def testMinimizeLossGraph(self, num_gpus=2):
+    # Collective ops don't support a strategy with only one device.
+ if context.num_gpus() < num_gpus:
+ return
+ distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=num_gpus)
+ self._test_minimize_loss_graph(distribution)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index b6037d2133..9b5534393e 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -267,9 +267,9 @@ def _group_value_by_device(per_device_values):
This grouping is needed to call the all-reduce library because it expects a
list of the following form:
- [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...
- (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...
- (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...
+ [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
+ [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
+     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
...
]
@@ -290,7 +290,10 @@ def _group_value_by_device(per_device_values):
return grouped
-def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
+def _ungroup_and_make_mirrored(grouped_reduced,
+ destinations,
+ aggregation,
+ num_between_graph_workers=1):
"""Ungroup results from all-reduce and make Mirrored objects.
Each all-reduce result will be divided by the number of destinations before
@@ -303,6 +306,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
destinations: a list of device strings for returned Mirrored objects.
aggregation: Indicates how a variable will be aggregated. Accepted values
are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
+ num_between_graph_workers: number of workers in the between-graph
+ replication.
Returns:
a list of Mirrored objects.
@@ -311,7 +316,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
for d, per_device_reduced in enumerate(grouped_reduced):
for i, (v, _) in enumerate(per_device_reduced):
if aggregation == vs.VariableAggregation.MEAN:
- index[i][destinations[d]] = v / len(destinations)
+ index[i][destinations[d]] = v / (
+ len(destinations) * num_between_graph_workers)
else:
index[i][destinations[d]] = v
return [value_lib.Mirrored(v) for v in index]
@@ -561,12 +567,12 @@ class AllReduceCrossTowerOps(CrossTowerOps):
def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
- logging.info(
- "batch_all_reduce invoked for batches size = %d with "
+ logging.log_first_n(
+ logging.INFO, "batch_all_reduce invoked for batches size = %d with "
"algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and "
- "agg_small_grads_max_group = %d", len(per_device_values),
- self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes,
- self._agg_small_grads_max_group)
+ "agg_small_grads_max_group = %d" %
+ (len(per_device_values), self._all_reduce_alg, self._num_packs,
+ self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)
destinations = per_device_values[0].devices
grouped = _group_value_by_device(per_device_values)
@@ -671,12 +677,13 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
- logging.info(
+ logging.log_first_n(
+ logging.INFO,
"distributed batch_all_reduce invoked for batches size = %d with "
"allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d "
- "and agg_small_grads_max_group = %d", len(per_device_values),
- self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes,
- self._agg_small_grads_max_group)
+ "and agg_small_grads_max_group = %d" %
+ (len(per_device_values), self._all_reduce_spec, self._num_packs,
+ self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)
destinations = sorted(per_device_values[0].devices)
device_grads = _group_value_by_device(per_device_values)
@@ -719,6 +726,102 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
aggregation)
+# TODO(yuefengz): support in-graph collective all-reduce.
+class CollectiveAllReduce(CrossTowerOps):
+ """All-reduce cross tower ops using collective ops.
+
+  In between-graph replicated training, it still does all-reduces across all
+  workers and then puts the results on the right destinations.
+ """
+
+ def __init__(self,
+ num_workers=1,
+ num_gpus_per_worker=0,
+ all_reduce_merge_scope=1,
+ collective_keys=None):
+ """Initializes the object.
+
+ Args:
+ num_workers: number of workers in the between-graph replicated training.
+ num_gpus_per_worker: number of GPUs per worker.
+ all_reduce_merge_scope: size of groups into which to partition consecutive
+ gradients grouped under a common 'allreduce' name scope. This is useful
+ for some optimization of collective ops.
+ collective_keys: an optional CollectiveKey object.
+ """
+ self._num_workers = num_workers
+ self._num_gpus_per_worker = num_gpus_per_worker
+ self._all_reduce_merge_scope = all_reduce_merge_scope
+ self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys(
+ )
+ super(CollectiveAllReduce, self).__init__()
+
+  # TODO(yuefengz, tucker): are IndexedSlices supported by collective ops?
+ def _reduce(self, aggregation, per_device_value, destinations):
+ all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
+ if destinations is None or _devices_match(per_device_value, destinations):
+ return all_reduced
+ else:
+ index = {}
+ for d in get_devices_from(destinations):
+ # pylint: disable=protected-access
+ if d in all_reduced._index:
+ index[d] = all_reduced._index[d]
+ else:
+ with ops.device(d):
+ index[d] = array_ops.identity(list(all_reduced._index.values())[0])
+ return value_lib.Mirrored(index)
+
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
+
+ def _batch_all_reduce(self, aggregation, per_device_values):
+ """All-reduce across all workers in a batch."""
+ if context.executing_eagerly():
+ raise ValueError("Eager mode with collective ops is not supported yet.")
+
+ logging.log_first_n(
+ logging.INFO, "Collective All-reduce invoked with batches size = %d, "
+ "num_workers = %d" % (len(per_device_values), self._num_workers), 10)
+
+ grouped_by_tower = _group_value_by_device(per_device_values)
+
+ grouped_by_var = list(zip(*grouped_by_tower))
+ # grouped_by_var is grouped by variables and takes the following format:
+ # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
+    #  ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu2, v1_gpu2) ..),
+    #  ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu2, v2_gpu2) ..),
+ # ...
+ # ]
+ chunked_gv = [
+ grouped_by_var[x:x + self._all_reduce_merge_scope]
+ for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope)
+ ]
+
+ reduced_gv_list = []
+ for chunk in chunked_gv:
+ with ops.name_scope("allreduce"):
+ for grad_and_vars in chunk:
+ scaled_grads = [g for g, _ in grad_and_vars]
+ collective_reduced = cross_tower_utils.build_collective_reduce(
+ scaled_grads, self._num_workers, self._collective_keys, "Add",
+ "Id")
+ result = []
+ for (_, v), g in zip(grad_and_vars, collective_reduced):
+ result.append([g, v])
+ reduced_gv_list.append(result)
+
+ new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
+ return _ungroup_and_make_mirrored(
+ new_tower_grads,
+ per_device_values[0].devices,
+ aggregation,
+ num_between_graph_workers=self._num_workers)
+
+
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
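The `num_between_graph_workers` argument threaded through `_ungroup_and_make_mirrored` above changes the MEAN divisor: with collective all-reduce the summed gradient already spans every device on every worker, so dividing by the local devices alone would over-count. A worked example with made-up numbers:

num_between_graph_workers = 3        # workers participating in the all-reduce
destinations = ["/gpu:0", "/gpu:1"]  # local devices per worker

# Each of the 3 * 2 = 6 replicas contributes a gradient of 2.0, so the
# collective all-reduce leaves 12.0 on every replica.
summed_grad = 12.0
mean_grad = summed_grad / (len(destinations) * num_between_graph_workers)
assert mean_grad == 2.0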
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 6a780ff60f..7b6c1843eb 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -21,13 +21,17 @@ from __future__ import print_function
import itertools
from absl.testing import parameterized
+import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -376,5 +380,173 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
self._testReductionAndBroadcast(cross_tower_ops, distribution)
+class MultiWorkerCollectiveAllReduceTest(
+ multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+
+ collective_key_base = 10000
+
+ @classmethod
+ def setUpClass(cls):
+    """Create a local cluster with 3 workers."""
+ cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+ cls._cluster_spec = {
+ run_config.TaskType.WORKER: [
+ "fake_worker_0", "fake_worker_1", "fake_worker_2"
+ ]
+ }
+
+ def setUp(self):
+ super(MultiWorkerCollectiveAllReduceTest, self).setUp()
+    # Reusing keys is not well supported, so we have to give a different
+    # collective key base to different tests.
+ MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000
+
+ def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False):
+ collective_keys = cross_tower_utils.CollectiveKeys(
+ group_key_start=10 * num_gpus +
+ MultiWorkerCollectiveAllReduceTest.collective_key_base,
+ instance_key_start=num_gpus * 100 +
+ MultiWorkerCollectiveAllReduceTest.collective_key_base,
+ instance_key_with_id_start=num_gpus * 10000 +
+ MultiWorkerCollectiveAllReduceTest.collective_key_base)
+ if local_mode:
+ collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
+ 1, num_gpus, collective_keys=collective_keys)
+ if num_gpus:
+ devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
+ else:
+ devices = ["/device:CPU:0"]
+ return collective_all_reduce_ops, devices, "local"
+ else:
+ collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
+ 3, num_gpus, collective_keys=collective_keys)
+ if num_gpus:
+ devices = [
+ "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i)
+ for i in range(num_gpus)
+ ]
+ else:
+ devices = ["/job:%s/task:%d" % (task_type, task_id)]
+ return collective_all_reduce_ops, devices, self._workers[task_id].target
+
+ def _assert_values_equal(self, left, right, sess):
+ if isinstance(left, list):
+ for l, r in zip(left, right):
+ self._assert_values_equal(l, r, sess)
+ else:
+ self.assertEqual(type(left), type(right))
+ self.assertEqual(set(left.devices), set(right.devices))
+
+ run_options = config_pb2.RunOptions()
+ run_options.experimental.collective_graph_key = 6
+
+ left_values = np.array(
+ sess.run(list(left._index.values()), options=run_options)).flatten()
+ right_values = np.array(list(right._index.values())).flatten()
+ self.assertEqual(len(left_values), len(right_values))
+ for l, r in zip(left_values, right_values):
+ self.assertEqual(l, r)
+
+ def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False):
+ collective_all_reduce, devices, master_target = self._get_test_objects(
+ task_type, task_id, num_gpus, local_mode=local_mode)
+ if local_mode:
+ num_workers = 1
+ worker_device = None
+ else:
+ num_workers = len(self._workers)
+ worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ with ops.Graph().as_default(), \
+ ops.device(worker_device), \
+ self.test_session(target=master_target) as sess:
+      # Collective ops don't support scalar tensors, so we have to construct
+      # 1-d tensors.
+ values = [constant_op.constant([float(d)]) for d in range(len(devices))]
+ per_device = _make_per_device(values, devices)
+ mean = np.array([(len(devices) - 1.) / 2.])
+
+ values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
+ per_device_2 = _make_per_device(values_2, devices)
+ mean_2 = np.array([mean[0] + 1.])
+
+ destination_mirrored = _fake_mirrored(1., devices)
+ destination_different = _fake_mirrored(1., _cpu_device)
+ destination_str = _cpu_device
+ destination_list = devices
+
+ all_destinations = [
+ None, destination_mirrored, destination_different, destination_str,
+ destination_list
+ ]
+
+ # test reduce()
+ for destinations in all_destinations:
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.MEAN,
+ per_device,
+ destinations=destinations),
+ _fake_mirrored(mean, destinations or per_device), sess)
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.MEAN,
+ per_device_2,
+ destinations=destinations),
+ _fake_mirrored(mean_2, destinations or per_device), sess)
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.SUM,
+ per_device,
+ destinations=destinations),
+ _fake_mirrored(mean * len(devices) * num_workers, destinations or
+ per_device), sess)
+ self._assert_values_equal(
+ collective_all_reduce.reduce(
+ vs.VariableAggregation.SUM,
+ per_device_2,
+ destinations=destinations),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or
+ per_device), sess)
+
+ # test batch_reduce()
+ for d1, d2 in itertools.product(all_destinations, all_destinations):
+ self._assert_values_equal(
+ collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN,
+ [(per_device, d1),
+ (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean, d1 or per_device),
+ _fake_mirrored(mean_2, d2 or per_device_2)
+ ], sess)
+ self._assert_values_equal(
+ collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
+ [(per_device, d1),
+ (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean * len(devices) * num_workers, d1 or
+ per_device),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or
+ per_device_2)
+ ], sess)
+
+ return True
+
+ @combinations.generate(
+ combinations.combine(mode=["graph"], num_gpus=[0, 1, 2]))
+ def testReductionDistributed(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(self._test_reduction, self._cluster_spec,
+ num_gpus)
+
+  # Collective ops don't support a strategy with only one device.
+ def testReductionLocal(self, num_gpus=2):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_reduction, self._cluster_spec, num_gpus, local_mode=True)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 2bb088e704..24cb08fb48 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -19,13 +19,16 @@ from __future__ import division
from __future__ import print_function
import collections as pycoll
+import threading
from tensorflow.contrib import nccl
from tensorflow.contrib.all_reduce.python import all_reduce
from tensorflow.contrib.distribute.python import values as value_lib
+from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
@@ -218,6 +221,146 @@ def split_grads_by_size(threshold_size, device_grads):
return small_grads, large_grads
+# threading.Lock() cannot be pickled and therefore cannot be a field of
+# CollectiveKeys.
+_lock = threading.Lock()
+
+
+# TODO(yuefengz): use random key starts to avoid reusing keys?
+class CollectiveKeys(object):
+ """Class that manages collective keys.
+
+ We need to manage three different keys for collective:
+
+ *Group key*: an integer key to identify the set of cooperative devices.
+  Collective ops that work on the same set of devices must use the same group
+  key.
+
+  *Instance key*: an integer key to identify the corresponding set of tensors
+  on different devices in a device group that need to be all-reduced together.
+
+  *Graph key*: an integer key that is unique per graph. This is used to support
+ multiple graphs per client session. It must be non-zero and set in the
+ `config` argument of each call to `session.run`.
+ """
+
+ def __init__(self,
+ group_key_start=1,
+ instance_key_start=100,
+ instance_key_with_id_start=10000):
+ """Initializes the object.
+
+ Args:
+ group_key_start: the starting integer of group key.
+ instance_key_start: the starting integer of instance key.
+ instance_key_with_id_start: the starting integer of instance key that is
+ recorded with an id.
+ """
+ self._group_key = group_key_start
+ self._group_key_table = dict()
+
+ # For instance keys with ids
+ self._instance_key_id_to_key_table = dict()
+ self._instance_key_with_id_counter = instance_key_with_id_start
+
+ # For instance keys without ids
+ self._instance_key_start = instance_key_start
+
+ self._thread_local = threading.local()
+
+ def _get_thread_local_object(self):
+    # We make the instance key without key ids thread-local so that it works
+    # with MirroredStrategy and the distribute coordinator.
+ if not hasattr(self._thread_local, 'instance_key'):
+ self._thread_local.instance_key = self._instance_key_start
+ return self._thread_local
+
+ def get_group_key(self, devices):
+ """Returns a group key for the set of devices.
+
+ Args:
+ devices: list of strings naming devices in a collective group.
+
+ Returns:
+ int key uniquely identifying the set of device names.
+ """
+ parsed = [pydev.DeviceSpec.from_string(d) for d in devices]
+    # In between-graph replicated training, different workers need to get the
+    # same group key. So we remove the task_type and task_id from the device
+    # names.
+    # TODO(yuefengz): in in-graph replicated training, we need to include
+    # task_type and task_id.
+ names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed])
+ key_id = ','.join(names)
+ with _lock:
+ if key_id not in self._group_key_table:
+ new_key = self._group_key
+ self._group_key += 1
+ self._group_key_table[key_id] = new_key
+ return self._group_key_table[key_id]
+
+ def get_instance_key(self, key_id=None):
+ """Returns a new instance key for use in defining a collective op.
+
+ Args:
+      key_id: optional string. If set, the key will be recorded and the same
+        key will be returned when the same key_id is provided. If not set, a
+        new, increasing instance key will be returned.
+ """
+ if key_id:
+ with _lock:
+ if key_id not in self._instance_key_id_to_key_table:
+ self._instance_key_with_id_counter += 1
+ self._instance_key_id_to_key_table[key_id] = (
+ self._instance_key_with_id_counter)
+ return self._instance_key_id_to_key_table[key_id]
+ else:
+ v = self._get_thread_local_object().instance_key
+ self._get_thread_local_object().instance_key += 1
+ return v
+
+
+def build_collective_reduce(input_tensors,
+ num_workers,
+ collective_keys,
+ reduction_op='Add',
+ unary_op='Id'):
+ """Build a subgraph that does one full all-reduce, using the collective Op.
+
+ Args:
+ input_tensors: tensors within a single worker graph that are to be reduced
+ together; must be one per device.
+ num_workers: total number of workers with identical independent graphs that
+ will be doing this same reduction. The reduction will actually include
+ the corresponding tensors at all these workers.
+ collective_keys: a CollectiveKeys object.
+ reduction_op: string naming the reduction op.
+ unary_op: string naming the unary final op.
+
+ Returns:
+ An array of final tensors, one per device, computed by the full reduction.
+
+ Raises:
+    ValueError: if `num_workers * len(input_tensors)` is less than 2.
+ """
+ group_size = len(input_tensors) * num_workers
+ if group_size < 2:
+ raise ValueError('num_workers * len(input_tensors) must be 2 or greater')
+ devices = [t.device for t in input_tensors]
+ num_devices = len(devices)
+ group_key = collective_keys.get_group_key(devices)
+ instance_key = collective_keys.get_instance_key()
+ out_tensors = []
+ subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec
+ for d in range(num_devices):
+ with ops.device(devices[d]):
+ reduce_op = collective_ops.all_reduce(
+ input_tensors[d], group_size, group_key, instance_key, reduction_op,
+ unary_op, subdiv_offsets)
+ out_tensors.append(reduce_op)
+ return out_tensors
+
+
def sum_grad_and_var_all_reduce(grad_and_vars,
num_workers,
alg,
@@ -253,10 +396,10 @@ def sum_grad_and_var_all_reduce(grad_and_vars,
else:
raise ValueError('unsupported all_reduce alg: ', alg)
- result = []
- for (_, v), g in zip(grad_and_vars, summed_grads):
- result.append([g, v])
- return result
+ result = []
+ for (_, v), g in zip(grad_and_vars, summed_grads):
+ result.append([g, v])
+ return result
def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg,
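
A minimal sketch of how the new CollectiveKeys and build_collective_reduce helpers fit together for a single worker with two devices. Device names and tensor values here are illustrative; this only builds the reduction subgraph, and actually executing the collective ops additionally requires a session configured for collectives (e.g. with a collective group leader set, as in multi_worker_test_base below).

from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

keys = cross_tower_utils.CollectiveKeys()
with ops.device('/device:CPU:0'):
  t0 = constant_op.constant([1.0, 2.0])
with ops.device('/device:GPU:0'):
  t1 = constant_op.constant([3.0, 4.0])
# One tensor per device and one worker, so the collective group size is 2.
reduced = cross_tower_utils.build_collective_reduce(
    [t0, t1], num_workers=1, collective_keys=keys)
# `reduced` holds one all-reduced tensor per input device.
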
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 34b8a54d7b..fbdb376fcc 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -12,33 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for Keras Sequential and Functional models."""
+"""Tests for tf.keras models using DistributionStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
-
import numpy as np
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import values
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import keras as keras_lib
from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop
+
_RANDOM_SEED = 1337
_TRAIN_SIZE = 200
_INPUT_SIZE = (10,)
_NUM_CLASS = 2
+# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as
+# part of the tf.keras unit test suite.
def simple_sequential_model():
model = keras.models.Sequential()
model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE))
@@ -84,7 +91,7 @@ def get_ds_test_input_fn():
return dataset
-class TestKerasDistributionStrategy(test_util.TensorFlowTestCase):
+class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
def setUp(self):
self._base_dir = os.path.join(self.get_temp_dir(),
@@ -145,5 +152,367 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+
+class TestWithDistributionStrategy(test.TestCase):
+
+ def test_validating_dataset_input_tensors_with_shape_mismatch(self):
+ with self.test_session():
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
+ '/device:CPU:0'])
+ a = constant_op.constant([1, 2], shape=(1, 2))
+ b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
+ x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
+ y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
+ with strategy.scope():
+ # Removed device and input tensor shape details from the error message
+ # since the order of the device and the corresponding input tensor shape
+ # is not deterministic over different runs.
+ with self.assertRaisesRegexp(ValueError,
+ 'Input tensor shapes do not match for '
+ 'distributed tensor inputs '
+ 'DistributedValues:.+'):
+ distributed_training_utils.validate_distributed_dataset_inputs(
+ strategy, x, y)
+
+ def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
+ with self.test_session():
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
+ '/device:CPU:0'])
+ a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
+ b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
+ x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
+ y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
+ with strategy.scope():
+ # Removed device and input tensor dtype details from the error message
+ # since the order of the device and the corresponding input tensor dtype
+ # is not deterministic over different runs.
+ with self.assertRaisesRegexp(ValueError,
+ 'Input tensor dtypes do not match for '
+ 'distributed tensor inputs '
+ 'DistributedValues:.+'):
+ distributed_training_utils.validate_distributed_dataset_inputs(
+ strategy, x, y)
+
+ def test_calling_model_on_same_dataset(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ # Call fit with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+ model.predict(dataset, steps=2)
+
+ def test_fit_eval_and_predict_methods_on_dataset(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
+ '/device:CPU:0'])
+
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+ # Test with validation data
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ validation_data=dataset, validation_steps=2)
+
+ def test_unsupported_features(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ # Test with validation split
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not supported '
+ 'when input `x` is a dataset or a dataset iterator'):
+ model.fit(dataset,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_split=0.5, validation_steps=2)
+
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ ValueError, 'sample_weight is currently not supported when using '
+ 'DistributionStrategy.'):
+ model.fit(
+ dataset,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
+
+      # Test without specifying the `steps` argument.
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(dataset, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(dataset, verbose=0)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(dataset, verbose=0)
+
+ def test_calling_with_unsupported_predefined_callbacks(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ def schedule(_):
+ return 0.001
+ with self.assertRaisesRegexp(ValueError,
+ 'LearningRateScheduler callback is not '
+ 'supported with DistributionStrategy.'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
+
+ with self.assertRaisesRegexp(ValueError,
+ 'ReduceLROnPlateau callback is not '
+ 'supported with DistributionStrategy.'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ callbacks=[keras.callbacks.ReduceLROnPlateau()])
+ with self.assertRaisesRegexp(ValueError,
+ 'histogram_freq in the TensorBoard callback '
+ 'is not supported when using '
+ 'DistributionStrategy.'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)])
+
+ def test_dataset_input_shape_validation(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+
+ model.compile(optimizer, loss, distribute=strategy)
+
+ # User forgets to batch the dataset
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'expected input to have 2 dimensions'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+
+ # Wrong input shape
+ inputs = np.zeros((10, 5), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'expected input to have shape'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+
+ def test_learning_phase_value(self):
+    # TODO(anjalisridhar): Modify this test to use Lambdas so that we can
+    # compare meaningful values. Currently we don't pass the learning phase if
+    # the Lambda layer uses the learning phase.
+ with self.test_session():
+ x = keras.layers.Input(shape=(16,), name='input')
+ y = keras.layers.Dense(16)(x)
+ z = keras.layers.Dropout(0.9999)(y)
+ model = keras.Model(x, z)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.005)
+ loss = 'mse'
+ metrics = ['acc']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
+ '/device:CPU:0'])
+
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ inputs = np.random.rand(10, 16)
+ targets = np.ones((10, 16), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(8)
+
+ hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1)
+ self.assertEqual(hist.history['acc'][0], 1)
+
+ evaluate_output = model.evaluate(dataset, steps=20)
+ self.assertEqual(evaluate_output[1], 0)
+
+ predict_output = model.predict(dataset, steps=1)
+ self.assertNotEqual(np.mean(predict_output), 0)
+
+
+class LossMaskingWithDistributionStrategyTest(test.TestCase):
+
+ def test_masking(self):
+ with self.test_session():
+ np.random.seed(1337)
+ x = np.array([[[1], [1]], [[0], [0]]])
+ model = keras.models.Sequential()
+ model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(1, kernel_initializer='one')))
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+
+ model.compile(loss='mse',
+ optimizer=gradient_descent.GradientDescentOptimizer(0.01),
+ distribute=strategy)
+ y = np.array([[[1], [1]], [[1], [1]]])
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2)
+ self.assertEqual(hist.history['loss'][0], 0)
+
+
+class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
+
+ def test_batchnorm_correctness(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
+ model.add(norm)
+ strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0',
+ '/device:GPU:0'])
+ model.compile(loss='mse',
+ optimizer=gradient_descent.GradientDescentOptimizer(0.01),
+ distribute=strategy)
+
+      # centered on 5.0, stddev 10.0
+ x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(32)
+
+ model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
+ out = model.predict(dataset, steps=2)
+ out -= keras.backend.eval(norm.beta)
+ out /= keras.backend.eval(norm.gamma)
+ np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
+ np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+
+
+class CorrectnessWithDistributionStrategyTest(test.TestCase):
+
+ def test_correctness(self):
+ with self.test_session():
+ keras.backend.set_image_data_format('channels_last')
+ num_samples = 10000
+ x_train = np.random.rand(num_samples, 1)
+ y_train = 3 * x_train
+ x_train = x_train.astype('float32')
+ y_train = y_train.astype('float32')
+
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+
+ # With DistributionStrategy
+ dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
+ dataset_with = dataset_with.batch(32)
+ strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
+ '/device:GPU:0'],
+ prefetch_on_device=False)
+
+ model.compile(loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ distribute=strategy)
+ model.fit(x=dataset_with, epochs=1, steps_per_epoch=310)
+ wts_with_ds = model.get_weights()
+
+ x_predict = [[1], [2], [3], [4]]
+ predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict,
+ x_predict))
+ predict_dataset_with = predict_dataset_with.batch(2)
+ predict_with_ds = model.predict(predict_dataset_with, steps=1)
+ predict_with_ds = np.reshape(predict_with_ds, (4, 1))
+
+ # Without DistributionStrategy
+ dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train,
+ y_train))
+ dataset_without = dataset_without.batch(64)
+
+ model.compile(loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5))
+ model.fit(x=dataset_without, epochs=1, steps_per_epoch=310)
+ wts_without_ds = model.get_weights()
+
+ x_predict = [[1], [2], [3], [4]]
+ predict_dataset_without = dataset_ops.Dataset.from_tensor_slices((
+ x_predict, x_predict))
+ predict_dataset_without = predict_dataset_without.batch(4)
+ predict_without_ds = model.predict(predict_dataset_without, steps=1)
+
+      # Verify that the weights are the same within some tolerance.
+ np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3)
+      # Verify that the predicted outputs are the same within some tolerance.
+ np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3)
+
+
if __name__ == '__main__':
test.main()
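
The pattern these tests exercise, in brief: compile a tf.keras model with a DistributionStrategy via the new `distribute` argument and drive fit/evaluate/predict from a batched Dataset. A minimal sketch (the device list and data sizes are illustrative):

import numpy as np
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.training import gradient_descent

inputs = keras.layers.Input(shape=(3,))
outputs = keras.layers.Dense(4)(inputs)
model = keras.Model(inputs, outputs)

strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0',
                                               '/device:GPU:0'])
model.compile(gradient_descent.GradientDescentOptimizer(0.001), 'mse',
              distribute=strategy)

x = np.zeros((10, 3), dtype=np.float32)
y = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(100).batch(10)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
model.evaluate(dataset, steps=2, verbose=0)
model.predict(dataset, steps=2)
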
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 6c6bf14309..2f3d6bdd3f 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -19,7 +19,6 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import combinations
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import test
@@ -183,7 +182,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _dataset_fn():
dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
# Want to produce a fixed, known shape, so drop remainder when batching.
- dataset = dataset.apply(batching.batch_and_drop_remainder(4))
+ dataset = dataset.batch(4, drop_remainder=True)
return dataset
def _expected_fn(num_batches):
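
The change above is a direct migration: the contrib `batch_and_drop_remainder` transformation is now expressed with the core Dataset API's `drop_remainder` flag, which likewise yields a fixed, known batch shape. A sketch of the equivalent call:

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops

dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
# Previously: dataset.apply(batching.batch_and_drop_remainder(4))
dataset = dataset.batch(4, drop_remainder=True)
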
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index eb2d102012..0c26ae8dbc 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -209,6 +209,75 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
return values.Mirrored(value_updates)
+def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring
+ # Figure out what collections this variable should be added to.
+ # We'll add the MirroredVariable to those collections instead.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Get synchronization value
+ synchronization = kwargs.get("synchronization",
+ variable_scope.VariableSynchronization.ON_WRITE)
+ if synchronization == variable_scope.VariableSynchronization.NONE:
+ raise ValueError("`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please"
+ " change the `synchronization` for variable: " +
+ kwargs["name"])
+ elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+ # Variables that are to be synced on read are tower local.
+ is_tower_local = True
+ kwargs["trainable"] = False
+ elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+ synchronization == variable_scope.VariableSynchronization.AUTO):
+ # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+ is_tower_local = False
+ else:
+ raise ValueError("Invalid variable synchronization mode: " +
+ synchronization + " for variable: " + kwargs["name"])
+
+ # Get aggregation value
+ aggregation = kwargs.pop("aggregation",
+ variable_scope.VariableAggregation.NONE)
+ if aggregation not in [
+ variable_scope.VariableAggregation.NONE,
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]:
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ index = real_mirrored_creator(devices, *args, **kwargs)
+
+ if is_tower_local:
+ result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
+ else:
+ result = values.MirroredVariable(index, index[devices[0]], aggregation)
+
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ l.remove(v)
+ g.add_to_collections(collections, result)
+ return result
+
+
class MirroredStrategy(distribute_lib.DistributionStrategy):
"""Mirrors vars to distribute across multiple devices on a single machine.
@@ -243,54 +312,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
- # Figure out what collections this variable should be added to.
- # We'll add the MirroredVariable to those collections instead.
- collections = kwargs.pop("collections", None)
- if collections is None:
- collections = [ops.GraphKeys.GLOBAL_VARIABLES]
- kwargs["collections"] = []
-
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
- # Get synchronization value
- synchronization = kwargs.get(
- "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
- if synchronization == variable_scope.VariableSynchronization.NONE:
- raise ValueError("`NONE` variable synchronization mode is not "
- "supported with `Mirrored` distribution strategy. Please"
- " change the `synchronization` for variable: " +
- kwargs["name"])
- elif synchronization == variable_scope.VariableSynchronization.ON_READ:
- # Variables that are to be synced on read are tower local.
- is_tower_local = True
- kwargs["trainable"] = False
- elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
- synchronization == variable_scope.VariableSynchronization.AUTO):
- # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
- is_tower_local = False
- else:
- raise ValueError("Invalid variable synchronization mode: " +
- synchronization + " for variable: " + kwargs["name"])
-
- # Get aggregation value
- aggregation = kwargs.pop("aggregation",
- variable_scope.VariableAggregation.NONE)
- if aggregation not in [
- variable_scope.VariableAggregation.NONE,
- variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
- ]:
- raise ValueError("Invalid variable aggregation mode: " + aggregation +
- " for variable: " + kwargs["name"])
-
- # Ignore user-specified caching device, not needed for mirrored variables.
- kwargs.pop("caching_device", None)
-
- # TODO(josh11b,apassos): It would be better if variable initialization
- # was never recorded on the tape instead of having to do this manually
- # here.
- with tape.stop_recording():
+ def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
index = {}
for i, d in enumerate(devices):
with ops.device(d):
@@ -314,27 +339,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
+ return index
- if is_tower_local:
- result = values.TowerLocalVariable(index, index[devices[0]],
- aggregation)
- else:
- result = values.MirroredVariable(index, index[devices[0]], aggregation)
-
- if not context.executing_eagerly():
- g = ops.get_default_graph()
- # If "trainable" is True, next_creator() will add the member variables
- # to the TRAINABLE_VARIABLES collection, so we manually remove
- # them and replace with the MirroredVariable. We can't set
- # "trainable" to False for next_creator() since that causes functions
- # like implicit_gradients to skip those variables.
- if kwargs.get("trainable", True):
- collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
- l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
- for v in index.values():
- l.remove(v)
- g.add_to_collections(collections, result)
- return result
+ return _create_mirrored_variable(devices, _real_mirrored_creator, *args,
+ **kwargs)
def distribute_dataset(self, dataset_fn):
return values.PerDeviceDataset(
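
Seen from user code, the extracted _create_mirrored_variable behaves as before: variables created under a MirroredStrategy scope become MirroredVariables for ON_WRITE/AUTO synchronization and non-trainable TowerLocalVariables for ON_READ. A sketch, assuming `get_variable` forwards the `synchronization` and `aggregation` arguments to the creator (device names are illustrative):

from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope

strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
                                               '/device:CPU:0'])
with strategy.scope():
  # ON_WRITE (or AUTO) synchronization yields a MirroredVariable.
  v = variable_scope.get_variable(
      'v', shape=[], initializer=init_ops.zeros_initializer(),
      synchronization=variable_scope.VariableSynchronization.ON_WRITE,
      aggregation=variable_scope.VariableAggregation.MEAN)
  # ON_READ synchronization yields a non-trainable TowerLocalVariable.
  c = variable_scope.get_variable(
      'counter', shape=[], initializer=init_ops.zeros_initializer(),
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM)
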
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index fa479918bd..249de01f08 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -20,11 +20,14 @@ from __future__ import print_function
import contextlib
import copy
+import threading
+import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
-from tensorflow.python.eager import test
+from tensorflow.python.estimator import run_config
+from tensorflow.python.platform import test
from tensorflow.python.framework import test_util
@@ -35,6 +38,12 @@ def create_in_process_cluster(num_workers, num_ps):
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+  # Enable collective ops; this has no impact on non-collective ops.
+  # TODO(yuefengz, tucker): remove this after we move the initialization of
+  # the collective mgr to the session level.
+ worker_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
+
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
@@ -43,7 +52,7 @@ def create_in_process_cluster(num_workers, num_ps):
# We could've started the server in another process, we could then kill that
# process to terminate the server. The reasons why we don't want multiple
# processes are
- # 1) it is more difficult to manage these processes
+ # 1) it is more difficult to manage these processes;
# 2) there is something global in CUDA such that if we initialize CUDA in the
# parent process, the child process cannot initialize it again and thus cannot
# use GPUs (https://stackoverflow.com/questions/22950047).
@@ -51,7 +60,8 @@ def create_in_process_cluster(num_workers, num_ps):
num_workers,
num_ps=num_ps,
worker_config=worker_config,
- ps_config=ps_config)
+ ps_config=ps_config,
+ protocol='grpc')
class MultiWorkerTestBase(test.TestCase):
@@ -60,11 +70,18 @@ class MultiWorkerTestBase(test.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- workers, _ = create_in_process_cluster(num_workers=2, num_ps=0)
- cls._master_target = workers[0].target
+ cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0)
+
+ def setUp(self):
+    # We cache the session only within a single test because another test may
+    # use a different session config or master target.
+ self._thread_local = threading.local()
+ self._thread_local.cached_session = None
+ self._result = 0
+ self._lock = threading.Lock()
@contextlib.contextmanager
- def test_session(self, graph=None, config=None):
+ def test_session(self, graph=None, config=None, target=None):
"""Create a test session with master target set to the testing cluster.
This overrides the base class' method, removes arguments that are not needed
@@ -75,6 +92,7 @@ class MultiWorkerTestBase(test.TestCase):
graph: Optional graph to use during the returned session.
config: An optional config_pb2.ConfigProto to use to configure the
session.
+      target: the target of the session to connect to.
Yields:
A Session object that should be used as a context manager to surround
@@ -94,13 +112,46 @@ class MultiWorkerTestBase(test.TestCase):
rewriter_config_pb2.RewriterConfig.OFF)
if graph is None:
- if self._cached_session is None: # pylint: disable=access-member-before-definition
- self._cached_session = session.Session(
- graph=None, config=config, target=self._master_target)
- sess = self._cached_session
+ if getattr(self._thread_local, 'cached_session', None) is None:
+ self._thread_local.cached_session = session.Session(
+ graph=None, config=config, target=target or self._workers[0].target)
+ sess = self._thread_local.cached_session
with sess.graph.as_default(), sess.as_default():
yield sess
else:
with session.Session(
- graph=graph, config=config, target=self._master_target) as sess:
+ graph=graph, config=config, target=target or
+ self._workers[0].target) as sess:
yield sess
+
+ def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
+ **kwargs):
+ result = client_fn(task_type, task_id, num_gpus, *args, **kwargs)
+ if np.all(result):
+ with self._lock:
+ self._result += 1
+
+ def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
+ **kwargs):
+ """Runs several clients for between-graph replication.
+
+ Args:
+      client_fn: a function that accepts `task_type`, `task_id` and
+        `num_gpus`, and returns True if it succeeds.
+ cluster_spec: a dict specifying jobs in a cluster.
+ num_gpus: number of GPUs per worker.
+ *args: will be passed to `client_fn`.
+ **kwargs: will be passed to `client_fn`.
+ """
+ threads = []
+ for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]:
+ for task_id in range(len(cluster_spec.get(task_type, []))):
+ t = threading.Thread(
+ target=self._run_client,
+ args=(client_fn, task_type, task_id, num_gpus) + args,
+ kwargs=kwargs)
+ t.start()
+ threads.append(t)
+ for t in threads:
+ t.join()
+ self.assertEqual(self._result, len(threads))
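
For reference, the worker ConfigProto the cluster helper now builds looks like the sketch below; naming a collective group leader enables collective ops and has no effect on non-collective ops (the GPU memory fraction is illustrative):

from tensorflow.core.protobuf import config_pb2

worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = 0.2  # illustrative
worker_config.experimental.collective_group_leader = (
    '/job:worker/replica:0/task:0')
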
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index ad538b9e8e..cf29c0ed91 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import contextlib
import json
import threading
from absl.testing import parameterized
@@ -26,8 +25,6 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
@@ -43,12 +40,19 @@ from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
-class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
+class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
+ parameterized.TestCase):
@classmethod
def setUpClass(cls):
cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=2)
+ cls._cluster_spec = {
+ run_config.TaskType.WORKER: [
+ 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
+ ],
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ }
def setUp(self):
self._result = 0
@@ -57,40 +61,34 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
self._init_reached = 0
self._finish_condition = threading.Condition()
self._finish_reached = 0
+ super(ParameterServerStrategyTest, self).setUp()
+
+ def _get_test_objects(self, task_type, task_id, num_gpus):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=num_gpus)
+ if not task_type:
+ return distribution, ''
- def _get_ps_distribution_strategy(self, task_type, task_index, num_gpus=0):
tf_config = {
- 'cluster': {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ],
- run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
- },
+ 'cluster': self._cluster_spec,
'task': {
'type': task_type,
- 'index': task_index
+ 'index': task_id
}
}
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=num_gpus)
with self._lock:
# Accessing environment variables should be protected by locks because
# environment variables are shared by all threads.
with test.mock.patch.dict('os.environ',
{'TF_CONFIG': json.dumps(tf_config)}):
distribution.configure()
- return distribution
-
- @contextlib.contextmanager
- def _test_session(self, target):
- config = config_pb2.ConfigProto(allow_soft_placement=True)
- config.graph_options.optimizer_options.opt_level = -1
- with session.Session(graph=None, config=config, target=target) as sess:
- yield sess
+ return distribution, self._workers[task_id].target
- def _test_device_assignment_distributed(self, d, num_gpus=0):
+ def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
+ worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
+ d, _ = self._get_test_objects(task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
- self._test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._workers[0].target) as sess, \
d.scope():
# Define a variable outside the call_for_each_tower scope. This is not
@@ -108,12 +106,9 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
a = constant_op.constant(1.0)
b = constant_op.constant(2.0)
c = a + b
- self.assertEqual(a.device,
- '/job:worker/replica:0/task:1/%s' % last_part_device)
- self.assertEqual(b.device,
- '/job:worker/replica:0/task:1/%s' % last_part_device)
- self.assertEqual(c.device,
- '/job:worker/replica:0/task:1/%s' % last_part_device)
+ self.assertEqual(a.device, worker_device + '/' + last_part_device)
+ self.assertEqual(b.device, worker_device + '/' + last_part_device)
+ self.assertEqual(c.device, worker_device + '/' + last_part_device)
# The device scope is ignored for variables but not for normal ops.
with ops.device('/job:worker/task:0'):
@@ -143,13 +138,12 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
z_add = z.assign_add(y)
with ops.control_dependencies([z_add]):
f = z + c
- self.assertEqual(f.device,
- '/job:worker/replica:0/task:1/%s' % last_part_device)
+ self.assertEqual(f.device, worker_device + '/' + last_part_device)
# The device scope would merge with the default worker device.
with ops.device('/CPU:1'):
g = e + 1.0
- self.assertEqual(g.device, '/job:worker/replica:0/task:1/device:CPU:1')
+ self.assertEqual(g.device, worker_device + '/device:CPU:1')
      # The ops.colocate_with will be ignored when defining a variable but not
# for a normal tensor.
@@ -182,8 +176,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
def testDeviceAssignmentDistributed(self, num_gpus):
- d = self._get_ps_distribution_strategy('worker', 1, num_gpus=num_gpus)
- self._test_device_assignment_distributed(d, num_gpus=num_gpus)
+ self._test_device_assignment_distributed('worker', 1, num_gpus)
def _test_device_assignment_local(self,
d,
@@ -191,7 +184,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
variable_device='CPU',
num_gpus=0):
with ops.Graph().as_default(), \
- self._test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._workers[0].target) as sess, \
d.scope():
def model_fn():
@@ -272,30 +265,33 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- def testDeviceAssignmentLocal(self):
+ def testDeviceAssignmentLocalCPU(self):
distribution = parameter_server_strategy.ParameterServerStrategy(
num_gpus_per_worker=0)
self._test_device_assignment_local(
distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
+ def testDeviceAssignmentLocalOneGPU(self):
distribution = parameter_server_strategy.ParameterServerStrategy(
num_gpus_per_worker=1)
self._test_device_assignment_local(
distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
+ def testDeviceAssignmentLocalTwoGPUs(self):
distribution = parameter_server_strategy.ParameterServerStrategy(
num_gpus_per_worker=2)
self._test_device_assignment_local(
distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
- def _test_simple_increment(self, d, task_type, task_index, master_target):
+ def _test_simple_increment(self, task_type, task_id, num_gpus):
+ d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
if hasattr(d, '_cluster_spec') and d._cluster_spec:
num_workers = len(d._cluster_spec.as_dict().get('worker',
['dummy_worker']))
else:
num_workers = 1
with ops.Graph().as_default(), \
- self._test_session(target=master_target) as sess, \
+ self.test_session(target=master_target) as sess, \
d.scope():
def model_fn():
@@ -314,7 +310,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
if context.num_gpus() < d._num_gpus_per_worker:
return True
- if task_index == 0:
+ if task_id == 0:
variables.global_variables_initializer().run()
# Workers waiting for chief worker's initializing variables.
@@ -341,9 +337,10 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and
y_val == 20.0 + 1.0 * num_workers * d.num_towers)
- def _test_minimize_loss_graph(self, d, task_type, task_index, master_target):
+ def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
+ d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
- self._test_session(target=master_target) as sess, \
+ self.test_session(target=master_target) as sess, \
d.scope():
l = core.Dense(1, use_bias=False)
@@ -390,7 +387,7 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
if context.num_gpus() < d._num_gpus_per_worker:
return True
- if task_index == 0:
+ if task_id == 0:
variables.global_variables_initializer().run()
# Workers waiting for chief worker's initializing variables.
@@ -413,42 +410,20 @@ class ParameterServerStrategyTest(test.TestCase, parameterized.TestCase):
self.assertLess(error_after, error_before)
return error_after < error_before
- def _run_client(self, index, model_fn, num_gpus):
- task_type = run_config.TaskType.WORKER
- result = model_fn(
- self._get_ps_distribution_strategy(task_type, index, num_gpus=num_gpus),
- task_type, index, self._workers[index].target)
- if result:
- with self._lock:
- self._result += 1
-
- def _run_multiple_clients(self, num_clients, model_fn, num_gpus=0):
- threads = []
- for i in range(num_clients):
- t = threading.Thread(
- target=self._run_client, args=(i, model_fn, num_gpus))
- t.start()
- threads.append(t)
- for t in threads:
- t.join()
-
def testSimpleBetweenGraph(self):
- self._run_multiple_clients(3, self._test_simple_increment)
- self.assertEqual(self._result, 3)
+ self._run_between_graph_clients(self._test_simple_increment,
+ self._cluster_spec, 0)
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
def testLocalSimpleIncrement(self, num_gpus):
- d = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=num_gpus)
- self._test_simple_increment(d, 'dummy_worker', 0, '')
+ self._test_simple_increment(None, 0, num_gpus)
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
def testMinimizeLossGraph(self, num_gpus):
- self._run_multiple_clients(
- 3, self._test_minimize_loss_graph, num_gpus=num_gpus)
- self.assertEqual(self._result, 3)
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
if __name__ == '__main__':
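
The refactored tests configure ParameterServerStrategy for a given task by setting TF_CONFIG before calling configure(); in standalone code the equivalent setup looks roughly like this (job addresses are placeholders, matching the fake cluster spec used above):

import json
import os

from tensorflow.contrib.distribute.python import parameter_server_strategy

cluster_spec = {
    'worker': ['fake_worker_0', 'fake_worker_1', 'fake_worker_2'],
    'ps': ['fake_ps_0', 'fake_ps_1'],
}
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': cluster_spec,
    'task': {'type': 'worker', 'index': 1},
})
distribution = parameter_server_strategy.ParameterServerStrategy(
    num_gpus_per_worker=0)
distribution.configure()
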
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index bc53898539..f5497e0b21 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -21,15 +21,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import tpu
+from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import one_device_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.training import device_util
from tensorflow.python.util import nest
@@ -39,11 +43,11 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def __init__(self, num_cores_per_host=2):
# TODO(isaprykin): Generalize the defaults. They are currently tailored for
# the unit test.
- super(TPUStrategy, self).__init__('/cpu:0')
+ super(TPUStrategy, self).__init__('/device:CPU:0')
# TODO(isaprykin): Auto-detect number of cores and hosts.
self._num_cores_per_host = num_cores_per_host
# TODO(priyag): This should not be hardcoded here.
- self._host = '/task:0/device:CPU:0'
+ self._host = '/device:CPU:0'
def distribute_dataset(self, dataset_fn):
# TODO(priyag): Perhaps distribute across cores here.
@@ -54,7 +58,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# a mechanism to infer the outputs of `fn`. Pending b/110550782.
def _run_steps_on_dataset(self, fn, iterator, iterations,
initial_loop_values=None):
- # Enqueue ops
+
shapes = nest.flatten(iterator.output_shapes)
if any([not s.is_fully_defined() for s in shapes]):
raise ValueError(
@@ -93,9 +97,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
[constant_op.constant(0)],
parallel_iterations=1)
- # Dequeue ops
def dequeue_fn():
- dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
+ dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
return nest.pack_sequence_as(iterator.output_shapes, dequeued)
# Wrap `fn` for repeat.
@@ -110,17 +113,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
with ops.control_dependencies([fn_result]):
return array_ops.identity(ctx.last_step_outputs)
- # Repeat
# TODO(sourabhbajaj): The input to while loop should be based on the output
# type of the step_fn
def iterate_on_tpu():
- return tpu.repeat(iterations, run_fn, [initial_loop_values])
+ return training_loop.repeat(iterations, run_fn, [initial_loop_values])
- # Re-write and distribute computation.
- # TODO(sourabhbajaj): Convert the output to PerDevice variable and
- # implement support for that in reduce.
- last_step_tensor_outputs = tpu.batch_parallel(
- iterate_on_tpu, [], num_shards=self._num_cores_per_host)
+ replicate_inputs = [[]] * self._num_cores_per_host
+ outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
+ last_step_tensor_outputs = [list(x) for x in zip(*outputs)]
# Take index [0] of last_step_tensor_outputs as we wrapped
# initial_loop_values in a list in the `repeat` call.
@@ -139,11 +139,32 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return [tpu.shutdown_system()]
def _reduce(self, aggregation, value, destinations):
- del destinations # TPU is graph mode only. Rely on implicit Send/Recv.
+ graph = ops.get_default_graph()
+ context = graph._get_control_flow_context() # pylint: disable=protected-access
+    # If we're inside the ReplicateContext, the reduction should be done using
+    # CrossReplicaSum; outside it, we can directly use an add_n op.
+ while context:
+ if isinstance(context, tpu.TPUReplicateContext):
+ if aggregation == vs.VariableAggregation.MEAN:
+ # TODO(jhseu): Revisit once we support model-parallelism.
+ value *= (1. / self._num_cores_per_host)
+ return tpu_ops.cross_replica_sum(value)
+ context = context.outer_context
+
+    # Validate that the destination is the same as the host device.
+    # Note that we don't do this when in the replicate context because the
+    # reduction is performed on the TPU device itself.
+ devices = cross_tower_ops_lib.get_devices_from(destinations)
+ if len(devices) == 1:
+ assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
+ self._host)
+ else:
+ raise ValueError('Multiple devices are not supported for TPUStrategy')
+
+ output = math_ops.add_n(value)
if aggregation == vs.VariableAggregation.MEAN:
- # TODO(jhseu): Revisit once we support model-parallelism.
- value *= (1. / self._num_cores_per_host)
- return tpu_ops.cross_replica_sum(value)
+ return output * (1. / len(value))
+ return output
@property
def num_towers(self):
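
Outside the replicate context the new `_reduce` path is ordinary graph math: sum the per-core values with add_n and scale by 1/len(value) for a MEAN aggregation. A tiny sketch of that arithmetic (values are illustrative):

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops

value = [constant_op.constant(1.0), constant_op.constant(3.0)]
output = math_ops.add_n(value)       # 4.0
mean = output * (1. / len(value))    # 2.0
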
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index e31dbbe80f..16844e0d68 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -22,12 +22,9 @@ from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.saver import BaseSaverBuilder
-class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
+class Iterator(iterator_ops.EagerIterator):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
NOTE: Unlike the iterator created by the
@@ -82,30 +79,3 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
return super(Iterator, self)._next_internal()
-
- # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset
- # attributes(potential).
-
- class _Saveable(BaseSaverBuilder.SaveableObject):
- """SaveableObject for saving/restoring iterator state."""
-
- def __init__(self, iterator_resource, name):
- serialized_iterator = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- specs = [
- BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
- ]
- # pylint: disable=protected-access
- super(Iterator._Saveable, self).__init__(iterator_resource, specs, name)
-
- def restore(self, restored_tensors, restored_shapes):
- with ops.colocate_with(self.op):
- return gen_dataset_ops.deserialize_iterator(self.op,
- restored_tensors[0])
-
- def _gather_saveables_for_checkpoint(self):
-
- def _saveable_factory(name):
- return self._Saveable(self._resource, name)
-
- return {"ITERATOR": _saveable_factory}
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index acc605247f..a753d77580 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -306,6 +307,19 @@ class IteratorTest(test.TestCase):
checkpoint.restore(save_path)
self.assertEqual(2, iterator.get_next().numpy())
+ def testRestoreInReconstructedIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ dataset = Dataset.range(10)
+ for i in range(5):
+ iterator = datasets.Iterator(dataset)
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ checkpoint.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
+ for j in range(2):
+ self.assertEqual(i * 2 + j, iterator.get_next().numpy())
+ checkpoint.save(file_prefix=checkpoint_prefix)
+
class DatasetConstructorBenchmark(test.Benchmark):
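
The new test relies on the checkpointing support now inherited from the base EagerIterator rather than the removed _Saveable machinery. Using the public API, the save/restore round trip looks roughly like this (the checkpoint path is illustrative):

import tensorflow as tf

tf.enable_eager_execution()

dataset = tf.data.Dataset.range(10)
iterator = tf.contrib.eager.Iterator(dataset)
checkpoint = tf.train.Checkpoint(iterator=iterator)
print(iterator.get_next().numpy())             # 0
save_path = checkpoint.save('/tmp/iter/ckpt')  # illustrative path
# A freshly constructed iterator restores to the saved position.
iterator = tf.contrib.eager.Iterator(dataset)
tf.train.Checkpoint(iterator=iterator).restore(save_path)
print(iterator.get_next().numpy())             # 1
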
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 12155a459c..6f02c90368 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -15,8 +15,6 @@ py_library(
"//tensorflow/contrib/eager/python/examples/revnet:config",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
- "//tensorflow/contrib/eager/python/examples/sagan",
- "//tensorflow/contrib/eager/python/examples/sagan:config",
"//tensorflow/contrib/eager/python/examples/spinn:data",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
index e81351b1b1..34a9984b0e 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
@@ -211,8 +211,7 @@ class ImageNetInput(object):
dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input)
dataset = dataset.prefetch(batch_size)
- dataset = dataset.apply(
- tf.contrib.data.batch_and_drop_remainder(batch_size))
+ dataset = dataset.batch(batch_size, drop_remainder=True)
if self.transpose_input:
dataset = dataset.map(
lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels),
diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD
deleted file mode 100644
index b470a41d81..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/BUILD
+++ /dev/null
@@ -1,59 +0,0 @@
-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-# Model
-py_library(
- name = "config",
- srcs = ["config.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "ops",
- srcs = ["ops.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "sagan",
- srcs = ["sagan.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":ops",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-# Tests
-cuda_py_test(
- name = "ops_test",
- size = "small",
- srcs = ["ops_test.py"],
- additional_deps = [
- ":ops",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-cuda_py_test(
- name = "sagan_test",
- size = "large",
- srcs = ["sagan_test.py"],
- additional_deps = [
- ":config",
- ":sagan",
- "//tensorflow:tensorflow_py",
- ],
- tags = [
- "optonly",
- ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py
deleted file mode 100644
index 1967bbd867..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/config.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Self-attention generative adversarial with eager execution.
-
-Configuration in format of tf.contrib.training.HParams.
-Supports default 128x128 ImageNet.
-
-Reference [Self-Attention Generative Adversarial
-Networks](https://arxiv.org/pdf/1805.08318.pdf)
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-tfe = tf.contrib.eager
-
-
-def get_hparams_imagenet():
- """Configurations to train SAGAN on 128x128 ImageNet dataset."""
- config = tf.contrib.training.HParams()
- if tf.test.is_gpu_available():
- config.add_hparam("image_shape", (3, 128, 128))
- config.add_hparam("data_format", "channels_first")
- config.add_hparam("g_init_shape", (512, 4, 4))
- else:
- config.add_hparam("image_shape", (128, 128, 3))
- config.add_hparam("data_format", "channels_first")
- config.add_hparam("g_init_shape", (4, 4, 512))
-
- config.add_hparam("latent_dim", 128)
- config.add_hparam("update_g_once_every", 1)
- config.add_hparam("batch_size", 64)
- config.add_hparam("d_init_filters", 32)
- config.add_hparam("num_upsamples", 5)
- # (512, 4, 4) -> (3, 128, 128)
- return config
-
-
-def get_hparams_mock():
- """Configurations of smaller networks for testing."""
- config = tf.contrib.training.HParams()
- if tf.test.is_gpu_available():
- config.add_hparam("image_shape", (3, 16, 16))
- config.add_hparam("data_format", "channels_first")
- config.add_hparam("g_init_shape", (32, 2, 2))
- else:
- config.add_hparam("image_shape", (16, 16, 3))
- config.add_hparam("data_format", "channels_last")
- config.add_hparam("g_init_shape", (2, 2, 32))
-
- config.add_hparam("latent_dim", 16)
- config.add_hparam("update_g_once_every", 1)
- config.add_hparam("batch_size", 2)
- config.add_hparam("d_init_filters", 4)
- config.add_hparam("num_upsamples", 3)
- # (32, 2, 2) -> (3, 16, 16)
- return config
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py
deleted file mode 100644
index 9a03cab1d1..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/ops.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Self-attention generative adversarial with eager execution.
-
-Auxiliary operations.
-
-Reference [Self-Attention Generative Adversarial
-Networks](https://arxiv.org/pdf/1805.08318.pdf)
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-
-def flatten_hw(x, data_format="channels_first"):
- """Flatten the input tensor across height and width dimensions."""
- if data_format == "channels_last":
- x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first`
-
- old_shape = tf.shape(x)
- new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]]
-
- return tf.reshape(x, new_shape)
-
-
-def broaden_hw(x, h, w, c, data_format="channels_first"):
- """Broaden dimension so that output has height and width."""
- if data_format == "channels_first":
- shape = [-1, c, h, w]
- else:
- shape = [-1, h, w, c]
-
- return tf.reshape(x, shape)
-
-
-class BroadenHW(tf.keras.layers.Layer):
- """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`."""
-
- def __init__(self, h, w, c, data_format="channels_first"):
- super(BroadenHW, self).__init__()
- self.h = h
- self.w = w
- self.c = c
- self.data_format = data_format
-
- def call(self, x):
- return broaden_hw(
- x, h=self.h, w=self.w, c=self.c, data_format=self.data_format)
-
- def compute_output_shape(self, input_shape):
- input_shape = tf.TensorShape(input_shape).as_list()
- if self.data_format == "channels_first":
- output_shape = (input_shape[0], self.c, self.h, self.w)
- else:
- output_shape = (input_shape[0], self.h, self.w, self.c)
-
- return tf.TensorShape(output_shape)
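The shape contract of flatten_hw and broaden_hw above amounts to plain reshapes (element order is preserved; no transpose happens for channels_first input). A small NumPy sketch, illustrative only:

    import numpy as np

    n, c, h, w = 1, 3, 4, 4

    # flatten_hw, channels_first input: (N, C, H, W) -> (N, H*W, C)
    x = np.random.randn(n, c, h, w)
    flat = x.reshape(n, h * w, c)
    assert flat.shape == (n, h * w, c)

    # broaden_hw, channels_last output: (N, H*W*C) -> (N, H, W, C)
    vec = np.random.randn(n, h * w * c)
    broad = vec.reshape(n, h, w, c)
    assert broad.shape == (n, h, w, c)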
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
deleted file mode 100644
index 3454985904..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for auxiliary operations."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.sagan import ops
-
-
-class OpsTest(tf.test.TestCase):
-
- def test_flatten_hw(self):
- """Test `flatten_hw` function with mock object."""
-
- batch_size = 1
- # Default NCHW format
- if tf.test.is_gpu_available():
- x = tf.random_normal(shape=(batch_size, 3, 4, 4))
- y = ops.flatten_hw(x, data_format="channels_first")
- self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
-
- # NHWC format
- x = tf.random_normal(shape=(batch_size, 4, 4, 3))
- y = ops.flatten_hw(x, data_format="channels_last")
- self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
-
- def test_broaden_hw(self):
- """Test `broaden_hw` function with mock object."""
-
- batch_size = 1
- # NHWC format
- x = tf.random_normal(shape=[batch_size, 4 * 4 * 16])
- y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last")
- self.assertEqual(y.shape, (batch_size, 4, 4, 16))
-
- # Default NCHW format
- if tf.test.is_gpu_available():
- y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first")
- self.assertEqual(y.shape, (batch_size, 16, 4, 4))
-
-
-if __name__ == "__main__":
- tf.enable_eager_execution()
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
deleted file mode 100644
index 8130414985..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py
+++ /dev/null
@@ -1,232 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Self-attention generative adversarial with eager execution.
-
-Code for main model.
-
-Reference [Self-Attention Generative Adversarial
-Networks](https://arxiv.org/pdf/1805.08318.pdf)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.sagan import ops
-tfe = tf.contrib.eager
-
-
-class SelfAttentionModule(tf.keras.Model):
- """Self-attention module composed of convolutional layers."""
-
- def __init__(self,
- attention_features,
- original_features,
- data_format="channels_first"):
- """Initialize the module.
-
- Args:
- attention_features: Number of filters for the attention computation.
- original_features: Number of filters of the original Tensor.
- data_format: Either 'channels_first' or 'channels_last'
- """
- super(SelfAttentionModule, self).__init__()
- self.data_format = data_format
- # Matrix multiplication implemented as 2D Convolution
- self.f = tf.keras.layers.Conv2D(
- filters=attention_features,
- kernel_size=1,
- strides=(1, 1),
- data_format=data_format)
- self.g = tf.keras.layers.Conv2D(
- filters=attention_features,
- kernel_size=1,
- strides=(1, 1),
- data_format=data_format)
- self.h = tf.keras.layers.Conv2D(
- filters=original_features,
- kernel_size=1,
- strides=(1, 1),
- data_format=data_format)
- self.scale = tf.Variable(0., trainable=True)
-
- def call(self, x):
- f = self.f(x)
- g = self.g(x)
- h = self.h(x)
-
- f_flatten = ops.flatten_hw(f, data_format=self.data_format)
- g_flatten = ops.flatten_hw(g, data_format=self.data_format)
- h_flatten = ops.flatten_hw(h, data_format=self.data_format)
-
- s = tf.matmul(g_flatten, f_flatten, transpose_b=True)
- b = tf.nn.softmax(s, axis=-1)
- o = tf.matmul(b, h_flatten)
- y = self.scale * tf.reshape(o, tf.shape(x)) + x
-
- return y
-
- def compute_output_shape(self, input_shape):
- return input_shape
-
-
-class SAGAN(tf.contrib.checkpoint.Checkpointable):
- """Self-attention generative adversarial network."""
-
- def __init__(self, config):
- """Initialize the model.
-
- Args:
- config: tf.contrib.training.HParams object; specifies hyperparameters
- """
- super(SAGAN, self).__init__()
- self.config = config
- self.generator = self._construct_generator()
- self.discriminator = self._construct_discriminator()
-
- def _construct_generator(self):
- """Construct generator."""
- # TODO(lxuechen): Add spectral normalization for WGAN
- axis = 1 if self.config.data_format == "channels_first" else 3
-
- generator = tf.keras.Sequential()
- generator.add(
- tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,)))
- generator.add(
- tf.keras.layers.Dense(
- units=np.prod(self.config.g_init_shape), activation=tf.nn.relu))
-
- if self.config.data_format == "channels_first":
- c, h, w = self.config.g_init_shape
- else:
- h, w, c = self.config.g_init_shape
-
- # Reshape to NHWC/NCHW
- generator.add(
- ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format))
-
- filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)]
- filters_list[-1] = 3 # Standard RGB images
-
- for filters in filters_list[:len(filters_list) // 2]:
- generator.add(
- tf.keras.layers.Conv2DTranspose(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- use_bias=False,
- padding="SAME",
- data_format=self.config.data_format))
- generator.add(tf.keras.layers.BatchNormalization(axis=axis))
- generator.add(tf.keras.layers.Activation("relu"))
-
- # pylint: disable=undefined-loop-variable
- generator.add(
- SelfAttentionModule(
- original_features=filters,
- attention_features=filters // 8,
- data_format=self.config.data_format))
- # pylint: enable=undefined-loop-variable
-
- for filters in filters_list[len(filters_list) // 2:]:
- generator.add(
- tf.keras.layers.Conv2DTranspose(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- use_bias=False,
- padding="SAME",
- data_format=self.config.data_format))
- if filters == 3:
- # Assume Image rescaled to [-1, 1]
- generator.add(tf.keras.layers.Activation("tanh"))
- else:
- generator.add(tf.keras.layers.BatchNormalization(axis=axis))
- generator.add(tf.keras.layers.Activation("relu"))
-
- return generator
-
- def _construct_discriminator(self):
- """Construct discriminator."""
- # TODO(lxuechen): Add spectral normalization for WGAN
- discriminator = tf.keras.Sequential()
- discriminator.add(
- tf.keras.layers.InputLayer(input_shape=self.config.image_shape))
-
- filters_list = [
- self.config.d_init_filters * 2**p
- for p in range(self.config.num_upsamples)
- ]
-
- for filters in filters_list[:(len(filters_list) + 1) // 2]:
- discriminator.add(
- tf.keras.layers.Conv2D(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- padding="SAME",
- data_format=self.config.data_format))
- discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
-
- # pylint: disable=undefined-loop-variable
- discriminator.add(
- SelfAttentionModule(
- original_features=filters,
- attention_features=filters // 8,
- data_format=self.config.data_format))
- # pylint: enable=undefined-loop-variable
-
- for filters in filters_list[(len(filters_list) + 1) // 2:]:
- discriminator.add(
- tf.keras.layers.Conv2D(
- filters=filters,
- kernel_size=4,
- strides=(2, 2),
- padding="SAME",
- data_format=self.config.data_format))
- discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
-
- discriminator.add(tf.keras.layers.Flatten())
- discriminator.add(tf.keras.layers.Dense(units=1))
-
- return discriminator
-
- def compute_loss_and_grads(self, real_images, noise, training=True):
- """Compute loss and gradients for both generator and discriminator."""
- # TODO(lxuechen): Add gradient penalty for discriminator
- with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
- real_logits = self.discriminator(real_images, training=training)
-
- fake_images = self.generator.call(noise, training=training)
- fake_logits = self.discriminator.call(fake_images)
-
- g_loss = self.compute_g_loss(fake_logits)
- d_loss = self.compute_d_loss(fake_logits, real_logits)
-
- g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables)
- d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
-
- return g_loss, d_loss, g_grads, d_grads
-
- def compute_g_loss(self, fake_logits):
- return -tf.reduce_mean(fake_logits) # Hinge loss
-
- def compute_d_loss(self, fake_logits, real_logits):
- # Hinge loss
- real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits))
- fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits))
- return real_loss + fake_loss
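Written out as equations, the attention in SelfAttentionModule.call and the hinge losses above read as follows (a hedged reading of the code, with \gamma standing for the learned self.scale variable):

    y = \gamma \, \mathrm{softmax}\!\big(g(x)\, f(x)^{\top}\big)\, h(x) + x
    L_G = -\,\mathbb{E}\big[D(G(z))\big]
    L_D = \mathbb{E}\big[\max(0,\, 1 - D(x_{\mathrm{real}}))\big]
          + \mathbb{E}\big[\max(0,\, 1 + D(G(z)))\big]

Here f, g, h are the 1x1 convolutions, the softmax runs over the flattened spatial positions, and the expectations are the batch means computed by tf.reduce_mean.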
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
deleted file mode 100644
index 1834594510..0000000000
--- a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for self-attention generative adversarial network."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-from tensorflow.contrib.eager.python.examples.sagan import config as config_
-from tensorflow.contrib.eager.python.examples.sagan import sagan
-tfe = tf.contrib.eager
-
-
-class SAGANTest(tf.test.TestCase):
-
- def setUp(self):
- super(SAGANTest, self).setUp()
- config = config_.get_hparams_mock()
- self.noise_shape = (config.batch_size, config.latent_dim)
- self.logits_shape = (config.batch_size, 1)
- self.images_shape = (config.batch_size,) + config.image_shape
-
- self.model = sagan.SAGAN(config=config)
- self.noise = tf.random_normal(shape=self.noise_shape)
- self.real_images = tf.random_normal(shape=self.images_shape)
- self.config = config
-
- def tearDown(self):
- del self.model
- del self.noise
- del self.real_images
- super(SAGANTest, self).tearDown()
-
- def test_generator_call(self):
- """Test `generator.__call__` function."""
- fake_images = self.model.generator(self.noise, training=False)
- self.assertEqual(fake_images.shape, self.images_shape)
-
- def test_generator_call_defun(self):
- """Test `generator.__call__` function with defun."""
- call_ = tfe.defun(self.model.generator.__call__)
- fake_images = call_(self.noise, training=False)
- self.assertEqual(fake_images.shape, self.images_shape)
-
- def test_discriminator_call(self):
- """Test `discriminator.__call__` function."""
- real_logits = self.model.discriminator(self.real_images)
- self.assertEqual(real_logits.shape, self.logits_shape)
-
- def test_discriminator_call_defun(self):
- """Test `discriminator.__call__` function with defun."""
- call_ = tfe.defun(self.model.discriminator.__call__)
- real_logits = call_(self.real_images)
- self.assertEqual(real_logits.shape, self.logits_shape)
-
- def test_compute_loss_and_grads(self):
- """Test `compute_loss_and_grads` function."""
- g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads(
- self.real_images, self.noise, training=False)
- self.assertEqual(g_loss.shape, ())
- self.assertEqual(d_loss.shape, ())
- self.assertTrue(isinstance(g_grads, list))
- self.assertTrue(isinstance(d_grads, list))
- g_vars = self.model.generator.trainable_variables
- d_vars = self.model.discriminator.trainable_variables
-
- self.assertEqual(len(g_grads), len(g_vars))
- self.assertEqual(len(d_grads), len(d_vars))
-
- def test_compute_loss_and_grads_defun(self):
- """Test `compute_loss_and_grads` function with defun."""
- compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads)
- g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads(
- self.real_images, self.noise, training=False)
- self.assertEqual(g_loss.shape, ())
- self.assertEqual(d_loss.shape, ())
- self.assertTrue(isinstance(g_grads, list))
- self.assertTrue(isinstance(d_grads, list))
- g_vars = self.model.generator.trainable_variables
- d_vars = self.model.discriminator.trainable_variables
-
- self.assertEqual(len(g_grads), len(g_vars))
- self.assertEqual(len(d_grads), len(d_vars))
-
-
-if __name__ == "__main__":
- tf.enable_eager_execution()
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
index 8ac553e0ae..d18a097063 100644
--- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
+++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
@@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.checkpointable import util as checkpointable_utils
# pylint: enable=g-bad-import-order
@@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase):
# 5. Verify that checkpoints exist and contains all the expected variables.
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*")))
object_graph = checkpointable_utils.object_metadata(
- saver.latest_checkpoint(config.logdir))
+ checkpoint_management.latest_checkpoint(config.logdir))
ckpt_variable_names = set()
for node in object_graph.nodes:
for attribute in node.attributes:
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index dc49383c5c..918a7e2bc7 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -133,6 +133,7 @@ _nest_allowed_symbols = [
'flatten_dict_items',
'pack_sequence_as',
'map_structure',
+ 'map_structure_with_paths',
'assert_shallow_structure',
'flatten_up_to',
'map_structure_up_to',
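For context, map_structure (already whitelisted above) maps a callable over every leaf of an arbitrarily nested structure, and map_structure_with_paths is now exported alongside it. A minimal sketch using the long-standing map_structure API, with illustrative values only:

    from tensorflow.contrib.framework import nest

    structure = {"weights": [1.0, 2.0], "bias": 3.0}
    doubled = nest.map_structure(lambda v: v * 2.0, structure)
    # doubled == {"weights": [2.0, 4.0], "bias": 6.0}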
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
index 9e356dd965..e7184a01fb 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py
@@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as train
__all__ = [
@@ -40,7 +40,7 @@ __all__ = [
def _get_checkpoint_filename(filepattern):
"""Returns checkpoint filename given directory or specific filepattern."""
if gfile.IsDirectory(filepattern):
- return saver.latest_checkpoint(filepattern)
+ return checkpoint_management.latest_checkpoint(filepattern)
return filepattern
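The rename applied here (and in the contrib/learn files below) moves latest_checkpoint and get_checkpoint_state lookups from the saver module to checkpoint_management; the call signature is unchanged. A hedged before/after sketch with a hypothetical model directory:

    # Before: the helper was reached through the saver module.
    from tensorflow.python.training import saver
    ckpt = saver.latest_checkpoint("/tmp/model_dir")

    # After this change: the same helper lives in checkpoint_management.
    from tensorflow.python.training import checkpoint_management
    ckpt = checkpoint_management.latest_checkpoint("/tmp/model_dir")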
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index f3bbf6b4d7..7e6a0f14f6 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -174,7 +174,7 @@ class GdrMemoryManager : public RemoteMemoryManager {
// Client side endpoints
mutex client_mu_;
std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
- GUARDED_BY(cient_mu_);
+ GUARDED_BY(client_mu_);
// Managed memory regions
mutex alloc_mu_;
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 7a026a15e4..c1de42782e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
@@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
as_iterable=True,
iterate_batches=False):
# Check that model has been trained.
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator):
if not checkpoint_path:
# Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index f8a3709ee5..08e907a608 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
@@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
# Load and cache the path of the most recent checkpoint to avoid duplicate
# searches on GCS.
logging.info("Checking for checkpoint in %s", self._model_dir)
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.warning("Skipping evaluation and export since model has not been "
@@ -516,7 +516,8 @@ class Experiment(object):
start = time.time()
error_msg = None
- latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if not latest_path:
error_msg = ("Estimator is not fitted yet. "
"Will start an evaluation when a checkpoint is ready.")
@@ -778,7 +779,8 @@ class Experiment(object):
saving_listeners=self._saving_listeners)
logging.info("Evaluating model now.")
- latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_checkpoint = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
eval_result = self._call_evaluate(
input_fn=self._eval_input_fn,
steps=self._eval_steps,
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index 0d039d593b..df156da3f4 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
@@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 77f7c73d54..3d691d4340 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as core_summary
-from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
@@ -735,7 +735,8 @@ class ValidationMonitor(EveryN):
return False
self._last_checkpoint_check_time = current_time
# Check that we are not running evaluation on the same checkpoint.
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.debug("Skipping evaluation since model has not been saved yet "
"at step %d.", step)
@@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN):
def end(self, session=None):
super(ExportMonitor, self).end(session=session)
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.info("Skipping export at the end since model has not been saved "
"yet.")
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index 5c34d0ddb0..ff1da32c21 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -39,9 +39,9 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
from tensorflow.python.training import training_util
@@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase):
mock_latest_checkpoint.assert_called_with(model_dir)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_early_stopping_rounds(self,
mock_latest_checkpoint,
mock_estimator_class):
@@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase):
self._assert_validation_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase):
self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase):
monitor.epoch_end(epoch=0)
monitor.end()
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
model_dir = 'model/dir'
@@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase):
expected_best_metrics={'loss': 42.0, 'auc': 0.5})
monitor.post_step(step=step, session=None)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_fail_with_core_estimator_and_metrics(
self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index 3eacac7a3d..0144b93814 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
@@ -298,7 +299,8 @@ def _export_estimator(estimator,
# If checkpoint_path is specified, use the specified checkpoint path.
checkpoint_path = (checkpoint_path or
- tf_saver.latest_checkpoint(estimator._model_dir))
+ checkpoint_management.latest_checkpoint(
+ estimator._model_dir))
with ops.Graph().as_default() as g:
training_util.create_global_step(g)
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index f8106d1e4a..66af6833da 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@@ -714,7 +714,8 @@ def make_best_model_export_strategy(
# as soon as contrib is cleaned up and we can thus be sure that
# estimator is a tf.estimator.Estimator and not a
# tf.contrib.learn.Estimator
- checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ estimator.model_dir)
export_checkpoint_path, export_eval_result = best_model_selector.update(
checkpoint_path, eval_result)
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 7d7dd6b708..1e6f1e7da2 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -125,10 +125,22 @@ cc_library(
"graph_info.cc",
"interpreter.cc",
"model.cc",
- "nnapi_delegate.cc",
"op_resolver.cc",
"optional_debug_tools.cc",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "nnapi_delegate.cc",
+ "mmap_allocation.cc",
+ ],
+ "//tensorflow:windows": [
+ "nnapi_delegate_disabled.cc",
+ "mmap_allocation_disabled.cc",
+ ],
+ "//conditions:default": [
+ "nnapi_delegate_disabled.cc",
+ "mmap_allocation.cc",
+ ],
+ }),
hdrs = [
"allocation.h",
"context.h",
diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile
index df5954744a..9cc8f10b42 100644
--- a/tensorflow/contrib/lite/Makefile
+++ b/tensorflow/contrib/lite/Makefile
@@ -95,6 +95,7 @@ ARFLAGS := -r
INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/../../../ \
+-I$(MAKEFILE_DIR)/../../../../ \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
@@ -176,8 +177,12 @@ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \
$(MINIMAL_SRCS)
ifeq ($(BUILD_TYPE),micro)
CORE_CC_EXCLUDE_SRCS += \
-tensorflow/contrib/lite/model.cc \
+tensorflow/contrib/lite/mmap_allocation.cc \
tensorflow/contrib/lite/nnapi_delegate.cc
+else
+CORE_CC_EXCLUDE_SRCS += \
+tensorflow/contrib/lite/mmap_allocation_disabled.cc \
+tensorflow/contrib/lite/nnapi_delegate_disabled.cc
endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
@@ -214,8 +219,12 @@ all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
# The target that's compiled for micro-controllers
micro: $(LIB_PATH)
+# Hack for generating schema file bypassing flatbuffer parsing
+tensorflow/contrib/lite/schema/schema_generated.h:
+ @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h
+
# Gathers together all the objects we've compiled into a single '.a' archive.
-$(LIB_PATH): $(LIB_OBJS)
+$(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index ef6c14f085..8946261814 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -13,61 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <fcntl.h>
-#ifndef TFLITE_MCU
-#include <sys/mman.h>
-#endif
+#include "tensorflow/contrib/lite/allocation.h"
+
#include <sys/stat.h>
#include <sys/types.h>
-#include <unistd.h>
#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>
#include <utility>
-#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
-#ifndef TFLITE_MCU
-#include "tensorflow/contrib/lite/nnapi_delegate.h"
-#endif
namespace tflite {
#ifndef TFLITE_MCU
-MMAPAllocation::MMAPAllocation(const char* filename,
- ErrorReporter* error_reporter)
- : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) {
- mmap_fd_ = open(filename, O_RDONLY);
- if (mmap_fd_ == -1) {
- error_reporter_->Report("Could not open '%s'.", filename);
- return;
- }
- struct stat sb;
- fstat(mmap_fd_, &sb);
- buffer_size_bytes_ = sb.st_size;
- mmapped_buffer_ =
- mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
- if (mmapped_buffer_ == MAP_FAILED) {
- error_reporter_->Report("Mmap of '%s' failed.", filename);
- return;
- }
-}
-
-MMAPAllocation::~MMAPAllocation() {
- if (valid()) {
- munmap(const_cast<void*>(mmapped_buffer_), buffer_size_bytes_);
- }
- if (mmap_fd_ != -1) close(mmap_fd_);
-}
-
-const void* MMAPAllocation::base() const { return mmapped_buffer_; }
-
-size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; }
-
-bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; }
-
FileCopyAllocation::FileCopyAllocation(const char* filename,
ErrorReporter* error_reporter)
: Allocation(error_reporter) {
@@ -111,6 +72,7 @@ const void* FileCopyAllocation::base() const { return copied_buffer_.get(); }
size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; }
bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; }
+#endif
MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
ErrorReporter* error_reporter)
@@ -118,7 +80,6 @@ MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
buffer_ = ptr;
buffer_size_bytes_ = num_bytes;
}
-#endif
MemoryAllocation::~MemoryAllocation() {}
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index 827ea86503..121f3d2646 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -52,6 +52,8 @@ class MMAPAllocation : public Allocation {
size_t bytes() const override;
bool valid() const override;
+ static bool IsSupported();
+
protected:
// Data required for mmap.
int mmap_fd_ = -1; // mmap file descriptor
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 0b6568fd2f..8a8eb98568 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -111,6 +111,8 @@ typedef enum {
kTfLiteBuiltinPack = 83,
kTfLiteBuiltinLogicalOr = 84,
kTfLiteBuiltinOneHot = 85,
+ kTfLiteBuiltinLogicalAnd = 86,
+ kTfLiteBuiltinLogicalNot = 87,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index a28707382e..6242c09efe 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -39,6 +39,41 @@ cc_test(
)
cc_library(
+ name = "delegate",
+ srcs = [
+ "delegate.cc",
+ ],
+ hdrs = [
+ "delegate.h",
+ ],
+ deps = [
+ ":buffer_map",
+ ":delegate_data",
+ ":kernel",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "delegate_test",
+ size = "small",
+ srcs = ["delegate_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":delegate",
+ ":test_util",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
name = "delegate_data",
srcs = ["delegate_data.cc"],
hdrs = ["delegate_data.h"],
@@ -96,10 +131,20 @@ cc_test(
deps = [
":delegate_data",
":kernel",
+ ":test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "test_util",
+ testonly = True,
+ srcs = ["test_util.cc"],
+ hdrs = ["test_util.h"],
+ deps = [
+ "//tensorflow/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
- "//tensorflow/contrib/lite/testing:util",
"@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest",
"@flatbuffers",
],
)
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc
new file mode 100644
index 0000000000..673859da48
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+
+#include <vector>
+
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+#include "tensorflow/contrib/lite/util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tflite {
+namespace eager {
+namespace delegate {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
+ // Get the nodes in the current execution plan.
+ TfLiteIntArray* plan;
+ TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
+
+ // Add all custom ops starting with "Eager" to list of supported nodes.
+ std::vector<int> supported_nodes;
+ for (int node_index : TfLiteIntArrayView(plan)) {
+ TfLiteNode* node;
+ TfLiteRegistration* registration;
+ TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
+ context, node_index, &node, &registration));
+
+ if (registration->custom_name &&
+ strncmp(registration->custom_name, "Eager", 5) == 0) {
+ supported_nodes.push_back(node_index);
+ }
+ }
+
+ // Request TFLite to partition the graph and make kernels for each independent
+ // subgraph.
+ TfLiteIntArray* size_and_nodes =
+ ConvertVectorToTfLiteIntArray(supported_nodes);
+ context->ReplaceSubgraphsWithDelegateKernels(context, GetKernel(),
+ size_and_nodes, delegate);
+ TfLiteIntArrayFree(size_and_nodes);
+ return kTfLiteOk;
+}
+
+TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle, void* data,
+ size_t size) {
+ // TODO(nupurgarg): Make BufferMap unique to each interpreter in order to
+ // support multiple interpreters using a single delegate.
+ BufferMap* buffer_map =
+ reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap();
+
+ if (!buffer_map->HasTensor(buffer_handle)) {
+ fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle);
+ return kTfLiteError;
+ }
+
+ tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle);
+ tensorflow::StringPiece t_data = t.tensor_data();
+
+ if (size != t_data.size()) {
+ fprintf(stderr, "Not enough space to store TensorFlow's aligned buffer.\n");
+ return kTfLiteError;
+ }
+
+ memcpy(data, t_data.data(), t_data.size());
+ return kTfLiteOk;
+}
+
+} // namespace delegate
+} // namespace eager
+
+EagerDelegate::EagerDelegate() {
+ if (!eager::DelegateData::Create(&delegate_data_).ok()) {
+ fprintf(stderr, "Unable to initialize TensorFlow context.\n");
+ return;
+ }
+
+ delegate_.reset(new TfLiteDelegate{
+ /*data_=*/delegate_data_.get(),
+ /*nullptr,*/ &eager::delegate::Prepare,
+ /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
+ /*CopyToBufferHandle=*/nullptr,
+ /*FreeBufferHandle=*/nullptr});
+}
+
+EagerDelegate::~EagerDelegate() {}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
new file mode 100644
index 0000000000..6259b35931
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+
+namespace tflite {
+
+// WARNING: This is an experimental interface that is subject to change.
+// Delegate that can be used to extract parts of a graph that are designed to be
+// executed by TensorFlow's runtime via Eager.
+//
+// The interpreter must be constructed after the EagerDelegate and destructed
+// before the EagerDelegate. This delegate can only be used with one
+// interpreter.
+//
+// Usage:
+// EagerDelegate delegate;
+// ... build interpreter ...
+//
+// delegate.Apply(interpreter);
+// ... run inference ...
+// ... destroy interpreter ...
+// ... destroy delegate ...
+class EagerDelegate {
+ public:
+ EagerDelegate();
+ ~EagerDelegate();
+
+ TfLiteStatus Apply(Interpreter* interpreter) {
+ return interpreter->ModifyGraphWithDelegate(delegate_.get(),
+ /*allow_dynamic_tensors=*/true);
+ }
+
+ private:
+ std::unique_ptr<eager::DelegateData> delegate_data_;
+ std::unique_ptr<TfLiteDelegate> delegate_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
new file mode 100644
index 0000000000..88fb34044e
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -0,0 +1,150 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+
+namespace tflite {
+namespace eager {
+namespace {
+
+using ::testing::ContainsRegex;
+using ::testing::ElementsAre;
+
+// TODO(nupurgarg): Add a test with multiple interpreters for one delegate.
+
+class DelegateTest : public testing::EagerModelTest {
+ public:
+ DelegateTest() {
+ // The delegate needs to be constructed before the interpreter because the
+ // interpreter references data contained in the delegate.
+ delegate_.reset(new EagerDelegate());
+ interpreter_.reset(new Interpreter(&error_reporter_));
+ }
+
+ ~DelegateTest() override {
+ // The delegate needs to be destructed after the interpreter because the
+ // interpreter references data contained in the delegate.
+ delete interpreter_.release();
+ delete delegate_.release();
+ }
+
+ void ConfigureDelegate() {
+ CHECK(delegate_->Apply(interpreter_.get()) == kTfLiteOk);
+ }
+
+ private:
+ std::unique_ptr<EagerDelegate> delegate_;
+};
+
+TEST_F(DelegateTest, FullGraph) {
+ // Define the graph.
+ AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfOp(testing::kAdd, {1, 4}, {6});
+ AddTfOp(testing::kAdd, {2, 5}, {7});
+ AddTfOp(testing::kMul, {6, 7}, {8});
+
+ // Apply the delegate.
+ ConfigureDelegate();
+
+ // Define inputs.
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+ ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+}
+
+TEST_F(DelegateTest, MixedGraph) {
+ AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
+
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfOp(testing::kAdd, {1, 4}, {6});
+ AddTfOp(testing::kAdd, {2, 5}, {7});
+ AddTfLiteMulOp({6, 7}, {8});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(3, {2, 2, 1});
+ SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
+ ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+}
+
+TEST_F(DelegateTest, SplitGraph) {
+ AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
+
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kAdd, {1, 2}, {3});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+
+ AddTfLiteMulOp({4, 5}, {6});
+
+ AddTfOp(testing::kUnpack, {6}, {7, 8});
+ AddTfOp(testing::kAdd, {7, 8}, {9});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2, 2, 1});
+ SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(9), ElementsAre(1));
+ ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
+}
+
+TEST_F(DelegateTest, OnlyTFLite) {
+ // Only TFLite single op model.
+ AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
+ AddTfLiteMulOp({0, 1}, {2});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(1, {2, 2, 1});
+ SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
+ ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
+}
+
+} // namespace
+} // namespace eager
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
index 7d9dddef93..b7bfbb34e4 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
@@ -16,26 +16,16 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "absl/memory/memory.h"
-#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
-#include "tensorflow/contrib/lite/kernels/test_util.h"
-#include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
namespace tflite {
namespace eager {
namespace {
-using tensorflow::protobuf::TextFormat;
using ::testing::ContainsRegex;
using ::testing::ElementsAre;
-// We will use these as custom_names, so they need to be static.
-static const char kIdentity[] = "Identity";
-static const char kUnpack[] = "Unpack";
-static const char kAdd[] = "Add";
-static const char kMul[] = "Mul";
-
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
const std::vector<int>& supported_nodes) {
TfLiteIntArray* size_and_nodes =
@@ -46,39 +36,18 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
return kTfLiteOk;
}
-class KernelTest : public ::testing::Test {
+class KernelTest : public testing::EagerModelTest {
public:
KernelTest() {
CHECK(DelegateData::Create(&delegate_data_).ok());
interpreter_.reset(new Interpreter(&error_reporter_));
}
- bool Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
-
- void SetValues(int tensor_index, const std::vector<float>& values) {
- float* v = interpreter_->typed_tensor<float>(tensor_index);
- for (float f : values) {
- *v++ = f;
- }
- }
-
- std::vector<float> GetValues(int tensor_index) {
- TfLiteTensor* o = interpreter_->tensor(tensor_index);
- return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
- }
-
- void SetShape(int tensor_index, const std::vector<int>& values) {
- ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
- ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
- }
-
- std::vector<int> GetShape(int tensor_index) {
- std::vector<int> result;
- auto* dims = interpreter_->tensor(tensor_index)->dims;
- for (int i = 0; i < dims->size; ++i) {
- result.push_back(dims->data[i]);
- }
- return result;
+ ~KernelTest() override {
+ // The data needs to be released before the interpreter because the
+ // interpreter references the data.
+ delegate_data_.reset();
+ interpreter_.reset();
}
template <typename T>
@@ -99,112 +68,20 @@ class KernelTest : public ::testing::Test {
&delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk);
}
- void AddOp(const char* name, const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
- auto attr = [](const string& key, const string& value) {
- return " attr{ key: '" + key + "' value {" + value + "}}";
- };
-
- string attributes;
- if (name == string(kUnpack)) {
- attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
- attr("axis", "i: 0");
- } else if (name == string(kIdentity)) {
- attributes = attr("T", "type: DT_FLOAT");
- } else if (name == string(kAdd)) {
- attributes = attr("T", "type: DT_FLOAT");
- } else if (name == string(kMul)) {
- attributes = attr("T", "type: DT_FLOAT");
- }
- AddTfOp(name, attributes, inputs, outputs);
- }
-
- void AddTensors(int num_tensors, const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
- interpreter_->AddTensors(num_tensors);
- for (int i = 0; i < num_tensors; ++i) {
- TfLiteQuantizationParams quant;
- CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, kTfLiteFloat32,
- /*name=*/"",
- /*dims=*/{3}, quant),
- kTfLiteOk);
- }
-
- CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
- CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
- }
-
- const TestErrorReporter& error_reporter() const { return error_reporter_; }
-
- void AddTfLiteOp(const char* name, const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
- CHECK_EQ(string(name), kMul); // can only add MUL
- static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
- reg.builtin_code = BuiltinOperator_MUL;
- reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
- auto* i0 = &context->tensors[node->inputs->data[0]];
- auto* o = &context->tensors[node->outputs->data[0]];
- return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
- };
- reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
- auto* i0 = &context->tensors[node->inputs->data[0]];
- auto* i1 = &context->tensors[node->inputs->data[1]];
- auto* o = &context->tensors[node->outputs->data[0]];
- for (int i = 0; i < o->bytes / sizeof(float); ++i) {
- o->data.f[i] = i0->data.f[i] * i1->data.f[i];
- }
- return kTfLiteOk;
- };
-
- CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
- nullptr, &reg),
- kTfLiteOk);
- }
-
private:
- void AddTfOp(const char* name, const string& nodedef_str,
- const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
- static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
- reg.builtin_code = BuiltinOperator_CUSTOM;
- reg.custom_name = name;
-
- tensorflow::NodeDef nodedef;
- CHECK(TextFormat::ParseFromString(nodedef_str + " op: '" + name + "'",
- &nodedef));
- string serialized_nodedef;
- CHECK(nodedef.SerializeToString(&serialized_nodedef));
- flexbuffers::Builder fbb;
- fbb.Vector([&]() {
- fbb.String(nodedef.op());
- fbb.String(serialized_nodedef);
- });
- fbb.Finish();
-
- flexbuffers_.push_back(fbb.GetBuffer());
- auto& buffer = flexbuffers_.back();
- CHECK_EQ(interpreter_->AddNodeWithParameters(
- inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
- buffer.size(), nullptr, &reg),
- kTfLiteOk);
- }
-
- std::unique_ptr<Interpreter> interpreter_;
std::unique_ptr<DelegateData> delegate_data_;
TfLiteDelegate delegate_;
- std::vector<std::vector<uint8_t>> flexbuffers_;
- TestErrorReporter error_reporter_;
};
TEST_F(KernelTest, FullGraph) {
// Define the graph.
- AddTensors(9, {0, 3}, {8});
+ AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
- AddOp(kUnpack, {0}, {1, 2});
- AddOp(kUnpack, {3}, {4, 5});
- AddOp(kAdd, {1, 4}, {6});
- AddOp(kAdd, {2, 5}, {7});
- AddOp(kMul, {6, 7}, {8});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfOp(testing::kAdd, {1, 4}, {6});
+ AddTfOp(testing::kAdd, {2, 5}, {7});
+ AddTfOp(testing::kMul, {6, 7}, {8});
// Apply Delegate.
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
@@ -224,8 +101,8 @@ TEST_F(KernelTest, FullGraph) {
}
TEST_F(KernelTest, BadTensorFlowOp) {
- AddTensors(2, {0}, {1});
- AddOp("NonExistentOp", {0}, {1});
+ AddTensors(2, {0}, {1}, kTfLiteFloat32, {3});
+ AddTfOp(testing::kNonExistent, {0}, {1});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0});
@@ -240,8 +117,8 @@ TEST_F(KernelTest, BadTensorFlowOp) {
}
TEST_F(KernelTest, BadNumberOfOutputs) {
- AddTensors(3, {0}, {1, 2});
- AddOp(kIdentity, {0}, {1, 2});
+ AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3});
+ AddTfOp(testing::kIdentity, {0}, {1, 2});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0});
@@ -256,10 +133,10 @@ TEST_F(KernelTest, BadNumberOfOutputs) {
}
TEST_F(KernelTest, IncompatibleNodeDef) {
- AddTensors(2, {0}, {1});
+ AddTensors(2, {0}, {1}, kTfLiteFloat32, {3});
- // Cast is a TF op, but we don't add the proper nodedef to it in AddOp.
- AddOp("Cast", {0}, {1});
+ // Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp.
+ AddTfOp(testing::kIncompatibleNodeDef, {0}, {1});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0});
@@ -274,11 +151,11 @@ TEST_F(KernelTest, IncompatibleNodeDef) {
}
TEST_F(KernelTest, WrongSetOfNodes) {
- AddTensors(4, {0}, {3});
- AddOp(kUnpack, {0}, {1, 2});
- AddTfLiteOp(kMul, {1, 2}, {3});
+ AddTensors(4, {0}, {3}, kTfLiteFloat32, {3});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfLiteMulOp({1, 2}, {3});
- // Specify that kMul (#1) is supported when it actually isn't.
+ // Specify that testing::kMul (#1) is supported when it actually isn't.
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0, 1});
});
@@ -292,13 +169,13 @@ TEST_F(KernelTest, WrongSetOfNodes) {
}
TEST_F(KernelTest, MixedGraph) {
- AddTensors(9, {0, 3}, {8});
+ AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
- AddOp(kUnpack, {0}, {1, 2});
- AddOp(kUnpack, {3}, {4, 5});
- AddOp(kAdd, {1, 4}, {6});
- AddOp(kAdd, {2, 5}, {7});
- AddTfLiteOp(kMul, {6, 7}, {8});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
+ AddTfOp(testing::kAdd, {1, 4}, {6});
+ AddTfOp(testing::kAdd, {2, 5}, {7});
+ AddTfLiteMulOp({6, 7}, {8});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0, 1, 2, 3});
@@ -316,16 +193,16 @@ TEST_F(KernelTest, MixedGraph) {
}
TEST_F(KernelTest, SplitGraph) {
- AddTensors(10, {0}, {9});
+ AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
- AddOp(kUnpack, {0}, {1, 2});
- AddOp(kAdd, {1, 2}, {3});
- AddOp(kUnpack, {3}, {4, 5});
+ AddTfOp(testing::kUnpack, {0}, {1, 2});
+ AddTfOp(testing::kAdd, {1, 2}, {3});
+ AddTfOp(testing::kUnpack, {3}, {4, 5});
- AddTfLiteOp(kMul, {4, 5}, {6});
+ AddTfLiteMulOp({4, 5}, {6});
- AddOp(kUnpack, {6}, {7, 8});
- AddOp(kAdd, {7, 8}, {9});
+ AddTfOp(testing::kUnpack, {6}, {7, 8});
+ AddTfOp(testing::kAdd, {7, 8}, {9});
ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) {
return GenericPrepare(context, delegate, {0, 1, 2, 4, 5});
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc
new file mode 100644
index 0000000000..80acf5d995
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc
@@ -0,0 +1,154 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+
+#include "absl/memory/memory.h"
+#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
+
+namespace tflite {
+namespace eager {
+namespace testing {
+
+bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
+
+void EagerModelTest::SetValues(int tensor_index,
+ const std::vector<float>& values) {
+ float* v = interpreter_->typed_tensor<float>(tensor_index);
+ for (float f : values) {
+ *v++ = f;
+ }
+}
+
+std::vector<float> EagerModelTest::GetValues(int tensor_index) {
+ TfLiteTensor* o = interpreter_->tensor(tensor_index);
+ return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
+}
+
+void EagerModelTest::SetShape(int tensor_index,
+ const std::vector<int>& values) {
+ ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
+ ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+}
+
+std::vector<int> EagerModelTest::GetShape(int tensor_index) {
+ std::vector<int> result;
+ auto* dims = interpreter_->tensor(tensor_index)->dims;
+ result.reserve(dims->size);
+ for (int i = 0; i < dims->size; ++i) {
+ result.push_back(dims->data[i]);
+ }
+ return result;
+}
+
+void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
+ const std::vector<int>& outputs,
+ const TfLiteType& type,
+ const std::vector<int>& dims) {
+ interpreter_->AddTensors(num_tensors);
+ for (int i = 0; i < num_tensors; ++i) {
+ TfLiteQuantizationParams quant;
+ CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
+ /*name=*/"",
+ /*dims=*/dims, quant),
+ kTfLiteOk);
+ }
+
+ CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
+ CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
+}
+
+void EagerModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.builtin_code = BuiltinOperator_MUL;
+ reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* i0 = &context->tensors[node->inputs->data[0]];
+ auto* o = &context->tensors[node->outputs->data[0]];
+ return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
+ };
+ reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
+ auto* i0 = &context->tensors[node->inputs->data[0]];
+ auto* i1 = &context->tensors[node->inputs->data[1]];
+ auto* o = &context->tensors[node->outputs->data[0]];
+ for (int i = 0; i < o->bytes / sizeof(float); ++i) {
+ o->data.f[i] = i0->data.f[i] * i1->data.f[i];
+ }
+ return kTfLiteOk;
+ };
+
+ CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
+ nullptr, &reg),
+ kTfLiteOk);
+}
+
+void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ auto attr = [](const string& key, const string& value) {
+ return " attr{ key: '" + key + "' value {" + value + "}}";
+ };
+
+ if (op == kUnpack) {
+ string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
+ attr("axis", "i: 0");
+ AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
+ } else if (op == kIdentity) {
+ string attributes = attr("T", "type: DT_FLOAT");
+ AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
+ } else if (op == kAdd) {
+ string attributes = attr("T", "type: DT_FLOAT");
+ AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
+ } else if (op == kMul) {
+ string attributes = attr("T", "type: DT_FLOAT");
+ AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
+ } else if (op == kNonExistent) {
+ AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
+ } else if (op == kIncompatibleNodeDef) {
+ // "Cast" op is created without attributes - making it incompatible.
+ AddTfOp("EagerCast", "Cast", "", inputs, outputs);
+ }
+}
+
+void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
+ const string& nodedef_str,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
+ static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ reg.builtin_code = BuiltinOperator_CUSTOM;
+ reg.custom_name = tflite_name;
+
+ tensorflow::NodeDef nodedef;
+ CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
+ nodedef_str + " op: '" + tf_name + "'", &nodedef));
+ string serialized_nodedef;
+ CHECK(nodedef.SerializeToString(&serialized_nodedef));
+ flexbuffers::Builder fbb;
+ fbb.Vector([&]() {
+ fbb.String(nodedef.op());
+ fbb.String(serialized_nodedef);
+ });
+ fbb.Finish();
+
+ flexbuffers_.push_back(fbb.GetBuffer());
+ auto& buffer = flexbuffers_.back();
+ CHECK_EQ(interpreter_->AddNodeWithParameters(
+ inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
+ buffer.size(), nullptr, &reg),
+ kTfLiteOk);
+}
+
+} // namespace testing
+} // namespace eager
+} // namespace tflite
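
Editor's note: AddTfOp above packs each TensorFlow op as a flexbuffer vector of [op name, serialized NodeDef] and hands it to the interpreter as custom options. The following minimal sketch (not part of the patch) shows how such a buffer could be unpacked again on the consumer side; the function name and error handling are illustrative assumptions, only the buffer layout comes from the code above.

#include <cstddef>
#include <string>

#include "flatbuffers/flexbuffers.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical consumer of the custom-options buffer built by AddTfOp above.
bool UnpackNodeDef(const char* custom_options, size_t size,
                   std::string* op_name, tensorflow::NodeDef* nodedef) {
  const auto* data = reinterpret_cast<const uint8_t*>(custom_options);
  const flexbuffers::Vector& v = flexbuffers::GetRoot(data, size).AsVector();
  if (v.size() != 2) return false;
  // Element 0 is the TensorFlow op name, element 1 the serialized NodeDef.
  *op_name = v[0].AsString().str();
  return nodedef->ParseFromString(v[1].AsString().str());
}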
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h
new file mode 100644
index 0000000000..0eab9e1135
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.h
@@ -0,0 +1,97 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
+
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace eager {
+namespace testing {
+
+enum TfOpType {
+ kUnpack,
+ kIdentity,
+ kAdd,
+ kMul,
+ // Represents an op that does not exist in TensorFlow.
+ kNonExistent,
+ // Represents a valid TensorFlow op whose NodeDef is incompatible.
+ kIncompatibleNodeDef,
+};
+
+// This class creates models with TF and TFLite ops. In order to use this class
+// to test the Eager delegate, implement a function that calls
+// interpreter->ModifyGraphWithDelegate.
+class EagerModelTest : public ::testing::Test {
+ public:
+ EagerModelTest() {}
+ ~EagerModelTest() {}
+
+ bool Invoke();
+
+ // Sets the tensor's values at the given index.
+ void SetValues(int tensor_index, const std::vector<float>& values);
+
+ // Returns the tensor's values at the given index.
+ std::vector<float> GetValues(int tensor_index);
+
+ // Sets the tensor's shape at the given index.
+ void SetShape(int tensor_index, const std::vector<int>& values);
+
+ // Returns the tensor's shape at the given index.
+ std::vector<int> GetShape(int tensor_index);
+
+ const TestErrorReporter& error_reporter() const { return error_reporter_; }
+
+ // Adds `num_tensors` tensors to the model. `inputs` contains the indices of
+ // the input tensors and `outputs` contains the indices of the output
+ // tensors. All tensors are set to have `type` and `dims`.
+ void AddTensors(int num_tensors, const std::vector<int>& inputs,
+ const std::vector<int>& outputs, const TfLiteType& type,
+ const std::vector<int>& dims);
+
+ // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
+ // and `outputs` contains the indices of the output tensors.
+ void AddTfLiteMulOp(const std::vector<int>& inputs,
+ const std::vector<int>& outputs);
+
+ // Adds a TensorFlow op. `inputs` contains the indices of the
+ // input tensors and `outputs` contains the indices of the output tensors.
+ // This function is limited to the set of ops defined in TfOpType.
+ void AddTfOp(TfOpType op, const std::vector<int>& inputs,
+ const std::vector<int>& outputs);
+
+ protected:
+ std::unique_ptr<Interpreter> interpreter_;
+ TestErrorReporter error_reporter_;
+
+ private:
+ // Helper method to add a TensorFlow op. `tflite_name` needs to start with
+ // "Eager" in order to work with the Eager delegate.
+ void AddTfOp(const char* tflite_name, const string& tf_name,
+ const string& nodedef_str, const std::vector<int>& inputs,
+ const std::vector<int>& outputs);
+
+ std::vector<std::vector<uint8_t>> flexbuffers_;
+};
+
+} // namespace testing
+} // namespace eager
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
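
Editor's note: the class comment above says tests should derive from EagerModelTest and supply a function that calls interpreter->ModifyGraphWithDelegate. A minimal sketch of that pattern follows; the fixture name, the explicit interpreter construction, and the delegate hook are illustrative assumptions, only the EagerModelTest API comes from the header above.

#include "tensorflow/contrib/lite/delegates/eager/test_util.h"

namespace tflite {
namespace eager {

// Hypothetical fixture showing how a delegate test could be structured.
class ExampleDelegateTest : public testing::EagerModelTest {
 public:
  ExampleDelegateTest() {
    // The base class leaves interpreter_ unset; a concrete fixture creates it.
    interpreter_.reset(new Interpreter(&error_reporter_));
  }

 protected:
  // Applies `delegate` to the graph held by the fixture's interpreter_.
  void ApplyDelegate(TfLiteDelegate* delegate) {
    ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate), kTfLiteOk);
  }
};

TEST_F(ExampleDelegateTest, BuildsSmallGraph) {
  AddTensors(3, /*inputs=*/{0, 1}, /*outputs=*/{2}, kTfLiteFloat32, {2});
  AddTfOp(testing::kAdd, {0, 1}, {2});
  // A concrete test would construct an Eager delegate here, call
  // ApplyDelegate(...), then SetShape/SetValues, Invoke() and GetValues(2).
}

}  // namespace eager
}  // namespace tflite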
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
index e36218e4f1..6fdcf78b69 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
@@ -16,11 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/contrib/lite/version.h"
+#include "tensorflow/contrib/lite/examples/label_image/label_image.h"
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
@@ -28,8 +24,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/version.h"
-#include "tensorflow/contrib/lite/examples/label_image/label_image.h"
-
namespace tflite {
namespace label_image {
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity
index 9397d8f27a..bcf24b89e3 100644
--- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity
@@ -154,7 +154,7 @@ Camera:
m_Enabled: 1
serializedVersion: 2
m_ClearFlags: 1
- m_BackGroundColor: {r: 0.19215687, g: 0.3019608, b: 0.4745098, a: 0}
+ m_BackGroundColor: {r: 0.21933319, g: 0.21933319, b: 0.21933319, a: 0}
m_NormalizedViewPortRect:
serializedVersion: 2
x: 0
@@ -195,6 +195,100 @@ Transform:
m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
+--- !u!1 &871349752
+GameObject:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ serializedVersion: 5
+ m_Component:
+ - component: {fileID: 871349756}
+ - component: {fileID: 871349755}
+ - component: {fileID: 871349754}
+ - component: {fileID: 871349753}
+ m_Layer: 5
+ m_Name: Canvas
+ m_TagString: Untagged
+ m_Icon: {fileID: 0}
+ m_NavMeshLayer: 0
+ m_StaticEditorFlags: 0
+ m_IsActive: 1
+--- !u!114 &871349753
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 871349752}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 1301386320, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ m_IgnoreReversedGraphics: 1
+ m_BlockingObjects: 0
+ m_BlockingMask:
+ serializedVersion: 2
+ m_Bits: 4294967295
+--- !u!114 &871349754
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 871349752}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 1980459831, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ m_UiScaleMode: 0
+ m_ReferencePixelsPerUnit: 100
+ m_ScaleFactor: 1
+ m_ReferenceResolution: {x: 800, y: 600}
+ m_ScreenMatchMode: 0
+ m_MatchWidthOrHeight: 0
+ m_PhysicalUnit: 3
+ m_FallbackScreenDPI: 96
+ m_DefaultSpriteDPI: 96
+ m_DynamicPixelsPerUnit: 1
+--- !u!223 &871349755
+Canvas:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 871349752}
+ m_Enabled: 1
+ serializedVersion: 3
+ m_RenderMode: 0
+ m_Camera: {fileID: 0}
+ m_PlaneDistance: 100
+ m_PixelPerfect: 0
+ m_ReceivesEvents: 1
+ m_OverrideSorting: 0
+ m_OverridePixelPerfect: 0
+ m_SortingBucketNormalizedSize: 0
+ m_AdditionalShaderChannelsFlag: 0
+ m_SortingLayerID: 0
+ m_SortingOrder: 0
+ m_TargetDisplay: 0
+--- !u!224 &871349756
+RectTransform:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 871349752}
+ m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
+ m_LocalPosition: {x: 0, y: 0, z: 0}
+ m_LocalScale: {x: 0, y: 0, z: 0}
+ m_Children:
+ - {fileID: 1726294324}
+ m_Father: {fileID: 0}
+ m_RootOrder: 1
+ m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
+ m_AnchorMin: {x: 0, y: 0}
+ m_AnchorMax: {x: 0, y: 0}
+ m_AnchoredPosition: {x: 0, y: 0}
+ m_SizeDelta: {x: 0, y: 0}
+ m_Pivot: {x: 0, y: 0}
--- !u!1 &904015943
GameObject:
m_ObjectHideFlags: 0
@@ -240,3 +334,144 @@ MonoBehaviour:
- 1
- 3
- 7
+ inferenceText: {fileID: 1726294325}
+--- !u!1 &1726294323
+GameObject:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ serializedVersion: 5
+ m_Component:
+ - component: {fileID: 1726294324}
+ - component: {fileID: 1726294326}
+ - component: {fileID: 1726294325}
+ m_Layer: 5
+ m_Name: InferenceText
+ m_TagString: Untagged
+ m_Icon: {fileID: 0}
+ m_NavMeshLayer: 0
+ m_StaticEditorFlags: 0
+ m_IsActive: 1
+--- !u!224 &1726294324
+RectTransform:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 1726294323}
+ m_LocalRotation: {x: -0, y: -0, z: -0, w: 1}
+ m_LocalPosition: {x: 0, y: 0, z: 0}
+ m_LocalScale: {x: 1, y: 1, z: 1}
+ m_Children: []
+ m_Father: {fileID: 871349756}
+ m_RootOrder: 0
+ m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
+ m_AnchorMin: {x: 0.5, y: 0.5}
+ m_AnchorMax: {x: 0.5, y: 0.5}
+ m_AnchoredPosition: {x: 0, y: 25}
+ m_SizeDelta: {x: 450, y: 250}
+ m_Pivot: {x: 0.5, y: 0.5}
+--- !u!114 &1726294325
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 1726294323}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 708705254, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ m_Material: {fileID: 0}
+ m_Color: {r: 0.9338235, g: 0.9338235, b: 0.9338235, a: 1}
+ m_RaycastTarget: 1
+ m_OnCullStateChanged:
+ m_PersistentCalls:
+ m_Calls: []
+ m_TypeName: UnityEngine.UI.MaskableGraphic+CullStateChangedEvent, UnityEngine.UI,
+ Version=1.0.0.0, Culture=neutral, PublicKeyToken=null
+ m_FontData:
+ m_Font: {fileID: 10102, guid: 0000000000000000e000000000000000, type: 0}
+ m_FontSize: 35
+ m_FontStyle: 0
+ m_BestFit: 0
+ m_MinSize: 2
+ m_MaxSize: 40
+ m_Alignment: 4
+ m_AlignByGeometry: 0
+ m_RichText: 1
+ m_HorizontalOverflow: 0
+ m_VerticalOverflow: 0
+ m_LineSpacing: 1
+ m_Text: 'Inference took 0.0153 ms
+
+ Input: 1,3,7
+
+ Output: 3,9,21'
+--- !u!222 &1726294326
+CanvasRenderer:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 1726294323}
+--- !u!1 &2026426602
+GameObject:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ serializedVersion: 5
+ m_Component:
+ - component: {fileID: 2026426605}
+ - component: {fileID: 2026426604}
+ - component: {fileID: 2026426603}
+ m_Layer: 0
+ m_Name: EventSystem
+ m_TagString: Untagged
+ m_Icon: {fileID: 0}
+ m_NavMeshLayer: 0
+ m_StaticEditorFlags: 0
+ m_IsActive: 1
+--- !u!114 &2026426603
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 2026426602}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 1077351063, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ m_HorizontalAxis: Horizontal
+ m_VerticalAxis: Vertical
+ m_SubmitButton: Submit
+ m_CancelButton: Cancel
+ m_InputActionsPerSecond: 10
+ m_RepeatDelay: 0.5
+ m_ForceModuleActive: 0
+--- !u!114 &2026426604
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 2026426602}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: -619905303, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ m_FirstSelected: {fileID: 0}
+ m_sendNavigationEvents: 1
+ m_DragThreshold: 5
+--- !u!4 &2026426605
+Transform:
+ m_ObjectHideFlags: 0
+ m_PrefabParentObject: {fileID: 0}
+ m_PrefabInternal: {fileID: 0}
+ m_GameObject: {fileID: 2026426602}
+ m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
+ m_LocalPosition: {x: 0, y: 0, z: 0}
+ m_LocalScale: {x: 1, y: 1, z: 1}
+ m_Children: []
+ m_Father: {fileID: 0}
+ m_RootOrder: 2
+ m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
index abca814499..83291e6179 100644
--- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs
@@ -18,6 +18,7 @@ using System.Collections.Generic;
using System.Linq;
using TensorFlowLite;
using UnityEngine;
+using UnityEngine.UI;
/// <summary>
/// Simple example demonstrating use of the experimental C# bindings for TensorFlowLite.
@@ -30,14 +31,24 @@ public class HelloTFLite : MonoBehaviour {
[Tooltip("Configurable TFLite input tensor data.")]
public float[] inputs;
+ [Tooltip("Target Text widget for display of inference execution.")]
+ public Text inferenceText;
+
private Interpreter interpreter;
private float[] outputs;
+ void Awake() {
+ // As the demo is extremely simple, there's no need to run at full frame-rate.
+ QualitySettings.vSyncCount = 0;
+ Application.targetFrameRate = 5;
+ }
+
void Start () {
interpreter = new Interpreter(model.bytes);
- Debug.LogFormat("InputCount: {0}, OutputCount: {1}",
- interpreter.GetInputTensorCount(),
- interpreter.GetOutputTensorCount());
+ Debug.LogFormat(
+ "InputCount: {0}, OutputCount: {1}",
+ interpreter.GetInputTensorCount(),
+ interpreter.GetOutputTensorCount());
}
void Update () {
@@ -51,13 +62,17 @@ public class HelloTFLite : MonoBehaviour {
outputs = new float[inputs.Length];
}
+ float startTimeSeconds = Time.realtimeSinceStartup;
interpreter.SetInputTensorData(0, inputs);
interpreter.Invoke();
interpreter.GetOutputTensorData(0, outputs);
+ float inferenceTimeSeconds = Time.realtimeSinceStartup - startTimeSeconds;
- Debug.LogFormat("Input: {0}, Output: {1}",
- ArrayToString(inputs),
- ArrayToString(outputs));
+ inferenceText.text = string.Format(
+ "Inference took {0:0.0000} ms\nInput(s): {1}\nOutput(s): {2}",
+ inferenceTimeSeconds * 1000.0,
+ ArrayToString(inputs),
+ ArrayToString(outputs));
}
void OnDestroy() {
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset
index 74d7b532b0..a9bbfb02d1 100644
--- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset
@@ -35,6 +35,9 @@ GraphicsSettings:
- {fileID: 15106, guid: 0000000000000000f000000000000000, type: 0}
- {fileID: 10753, guid: 0000000000000000f000000000000000, type: 0}
- {fileID: 10770, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 17000, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 16000, guid: 0000000000000000f000000000000000, type: 0}
+ - {fileID: 16002, guid: 0000000000000000f000000000000000, type: 0}
m_PreloadedShaders: []
m_SpritesDefaultMaterial: {fileID: 10754, guid: 0000000000000000f000000000000000,
type: 0}
diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
index 0b3813fccb..c0dcb090b4 100644
--- a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
+++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md
@@ -22,3 +22,6 @@ bazel build -c opt --cxxopt=--std=c++11 \
--cpu=armeabi-v7a \
//tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so
```
+
+If you encounter issues with native plugin discovery on Mac ("Darwin")
+platforms, try renaming `libtensorflowlite_c.so` to `tensorflowlite_c.bundle`.
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD
new file mode 100644
index 0000000000..9c06c4ebd9
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/BUILD
@@ -0,0 +1,84 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+# ctc support classes imported directly from TensorFlow.
+cc_library(
+ name = "ctc_utils",
+ hdrs = [
+ "ctc_beam_entry.h",
+ "ctc_beam_scorer.h",
+ "ctc_beam_search.h",
+ "ctc_decoder.h",
+ "ctc_loss_util.h",
+ ],
+ deps = [
+ ":top_n",
+ "//tensorflow/contrib/lite/kernels/internal:types",
+ "//third_party/eigen3",
+ ],
+)
+
+# top_n support classes imported directly from TensorFlow.
+cc_library(
+ name = "top_n",
+ hdrs = [
+ "top_n.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/kernels/internal:types",
+ ],
+)
+
+cc_library(
+ name = "experimental_ops",
+ srcs = [
+ "ctc_beam_search_decoder.cc",
+ ],
+ # Suppress warnings that are introduced by Eigen Tensor.
+ copts = tflite_copts() + [
+ "-Wno-error=reorder",
+ ] + select({
+ "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
+ "//conditions:default": [
+ ],
+ }),
+ deps = [
+ ":ctc_utils",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/kernels:gemm_support",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
+ "//tensorflow/contrib/lite/kernels/internal:optimized",
+ "//tensorflow/contrib/lite/kernels/internal:optimized_base",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
+ "//tensorflow/contrib/lite/kernels/internal:reference",
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "ctc_beam_search_decoder_test",
+ size = "small",
+ srcs = ["ctc_beam_search_decoder_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":experimental_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h
new file mode 100644
index 0000000000..a60ff2a1c5
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h
@@ -0,0 +1,150 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/util/ctc/ctc_beam_entry.h
+// TODO(b/111524997): Remove this file.
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_
+
+#include <algorithm>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h"
+
+namespace tflite {
+namespace experimental {
+namespace ctc {
+
+// The ctc_beam_search namespace holds several classes that only need to be
+// accessed when extending the CTCBeamSearch decoder with custom scoring
+// functions.
+//
+// BeamEntry is exposed through template arguments BeamScorer and BeamComparer
+// of CTCBeamSearch (ctc_beam_search.h).
+namespace ctc_beam_search {
+
+struct EmptyBeamState {};
+
+struct BeamProbability {
+ BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {}
+ void Reset() {
+ total = kLogZero;
+ blank = kLogZero;
+ label = kLogZero;
+ }
+ float total;
+ float blank;
+ float label;
+};
+
+template <class CTCBeamState>
+class BeamRoot;
+
+template <class CTCBeamState = EmptyBeamState>
+struct BeamEntry {
+ // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
+ friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
+ BeamEntry<CTCBeamState>* p, int l);
+ inline bool Active() const { return newp.total != kLogZero; }
+ // Return the child at the given index, or construct a new one in-place if
+ // none was found.
+ BeamEntry& GetChild(int ind) {
+ auto entry = children.emplace(ind, nullptr);
+ auto& child_entry = entry.first->second;
+ // If this is a new child, populate the BeamEntry<CTCBeamState>*.
+ if (entry.second) {
+ child_entry = beam_root->AddEntry(this, ind);
+ }
+ return *child_entry;
+ }
+ std::vector<int> LabelSeq(bool merge_repeated) const {
+ std::vector<int> labels;
+ int prev_label = -1;
+ const BeamEntry* c = this;
+ while (c->parent != nullptr) { // Checking c->parent to skip root leaf.
+ if (!merge_repeated || c->label != prev_label) {
+ labels.push_back(c->label);
+ }
+ prev_label = c->label;
+ c = c->parent;
+ }
+ std::reverse(labels.begin(), labels.end());
+ return labels;
+ }
+
+ BeamEntry<CTCBeamState>* parent;
+ int label;
+ // All instances of child BeamEntry are owned by *beam_root.
+ std::unordered_map<int, BeamEntry<CTCBeamState>*> children;
+ BeamProbability oldp;
+ BeamProbability newp;
+ CTCBeamState state;
+
+ private:
+ // Constructor giving parent, label, and the beam_root.
+ // The object pointed to by p cannot be copied and should not be moved,
+ // otherwise parent will become invalid.
+ // This private constructor is only called through the factory method
+ // BeamRoot<CTCBeamState>::AddEntry().
+ BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
+ : parent(p), label(l), beam_root(beam_root) {}
+ BeamRoot<CTCBeamState>* beam_root;
+
+ BeamEntry(const BeamEntry&) = delete;
+ void operator=(const BeamEntry&) = delete;
+};
+
+// This class owns all instances of BeamEntry. This avoids recursive
+// destructor calls during destruction.
+template <class CTCBeamState = EmptyBeamState>
+class BeamRoot {
+ public:
+ BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
+ BeamRoot(const BeamRoot&) = delete;
+ BeamRoot& operator=(const BeamRoot&) = delete;
+
+ BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
+ auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
+ beam_entries_.emplace_back(new_entry);
+ return new_entry;
+ }
+ BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
+
+ private:
+ BeamEntry<CTCBeamState>* root_entry_ = nullptr;
+ std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
+};
+
+// BeamComparer is the default beam comparer provided in CTCBeamSearch.
+template <class CTCBeamState = EmptyBeamState>
+class BeamComparer {
+ public:
+ virtual ~BeamComparer() {}
+ virtual bool inline operator()(const BeamEntry<CTCBeamState>* a,
+ const BeamEntry<CTCBeamState>* b) const {
+ return a->newp.total > b->newp.total;
+ }
+};
+
+} // namespace ctc_beam_search
+
+} // namespace ctc
+} // namespace experimental
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_
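
Editor's note: as a small, non-authoritative illustration of the ownership model above (BeamRoot owns every BeamEntry, children are created lazily via GetChild, and LabelSeq walks parent pointers back to the root), the sketch below builds a two-level trie. The function name is illustrative; the types are exactly those declared in this header.

#include <vector>

#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h"

void BuildTinyBeamTrie() {
  using tflite::experimental::ctc::ctc_beam_search::BeamEntry;
  using tflite::experimental::ctc::ctc_beam_search::BeamRoot;

  BeamRoot<> root(/*p=*/nullptr, /*l=*/-1);        // Root entry has no parent.
  BeamEntry<>& a = root.RootEntry()->GetChild(0);  // Child for label 0.
  BeamEntry<>& ab = a.GetChild(1);                 // Grandchild for label 1.
  ab.newp.total = 0.0f;                            // ln(1): marks it active.
  // Walks ab -> a, skips the root, and reverses, so labels == {0, 1}.
  std::vector<int> labels = ab.LabelSeq(/*merge_repeated=*/false);
  (void)labels;
}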
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h
new file mode 100644
index 0000000000..ec60e26257
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Collection of scoring classes that can be extended and provided to the
+// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a
+// language model).
+//
+// To build a custom scorer, extend BaseBeamScorer and override its virtual
+// methods. The default CTC decoding behavior is implemented through
+// BaseBeamScorer.
+
+// Copied from tensorflow/core/util/ctc/ctc_beam_scorer.h
+// TODO(b/111524997): Remove this file.
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_
+
+#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h"
+
+namespace tflite {
+namespace experimental {
+namespace ctc {
+
+// Base implementation of a beam scorer used by default by the decoder that can
+// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex
+// scoring is required. Its main purpose is to provide a thin layer for
+// integrating language model scoring easily.
+template <typename CTCBeamState>
+class BaseBeamScorer {
+ public:
+ virtual ~BaseBeamScorer() {}
+ // State initialization.
+ virtual void InitializeState(CTCBeamState* root) const {}
+ // ExpandState is called when expanding a beam to one of its children.
+ // Called at most once per child beam. In the simplest case, no state
+ // expansion is done.
+ virtual void ExpandState(const CTCBeamState& from_state, int from_label,
+ CTCBeamState* to_state, int to_label) const {}
+ // ExpandStateEnd is called after decoding has finished. Its purpose is to
+ // allow a final scoring of the beam in its current state, before resorting
+ // and retrieving the TopN requested candidates. Called at most once per beam.
+ virtual void ExpandStateEnd(CTCBeamState* state) const {}
+ // GetStateExpansionScore should be an inexpensive method to retrieve the
+ // (cached) expansion score computed within ExpandState. The score is
+ // multiplied (log-addition) with the input score at the current step from
+ // the network.
+ //
+ // The score returned should be a log-probability. In the simplest case, as
+ // there's no state expansion logic, the expansion score is zero.
+ virtual float GetStateExpansionScore(const CTCBeamState& state,
+ float previous_score) const {
+ return previous_score;
+ }
+ // GetStateEndExpansionScore should be an inexpensive method to retrieve the
+ // (cached) expansion score computed within ExpandStateEnd. The score is
+ // multiplied (log-addition) with the final probability of the beam.
+ //
+ // The score returned should be a log-probability.
+ virtual float GetStateEndExpansionScore(const CTCBeamState& state) const {
+ return 0;
+ }
+};
+
+} // namespace ctc
+} // namespace experimental
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_
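
Editor's note: to make the extension point above concrete, here is a hedged sketch of a custom scorer that subclasses BaseBeamScorer and applies a fixed log-probability penalty to every label expansion. The class name and the penalty value are illustrative assumptions; only the BaseBeamScorer interface and EmptyBeamState come from the headers above.

#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h"
#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h"

namespace tflite {
namespace experimental {
namespace ctc {

// Hypothetical scorer: no per-beam state is needed, so EmptyBeamState is used.
class FlatPenaltyBeamScorer
    : public BaseBeamScorer<ctc_beam_search::EmptyBeamState> {
 public:
  // Discounts every label expansion by a constant log-probability.
  float GetStateExpansionScore(const ctc_beam_search::EmptyBeamState& state,
                               float previous_score) const override {
    return previous_score - 0.1f;
  }
};

}  // namespace ctc
}  // namespace experimental
}  // namespace tflite

Such a scorer would be passed to CTCBeamSearchDecoder in place of DefaultBeamScorer; as the decoder in the next file documents, ownership of the scorer stays with the caller.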
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
new file mode 100644
index 0000000000..c658e43092
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
@@ -0,0 +1,420 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/util/ctc/ctc_beam_search.h
+// TODO(b/111524997): Remove this file.
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h"
+#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h"
+#include "tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h"
+#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h"
+#include "tensorflow/contrib/lite/experimental/kernels/top_n.h"
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+namespace experimental {
+namespace ctc {
+
+template <typename CTCBeamState = ctc_beam_search::EmptyBeamState,
+ typename CTCBeamComparer =
+ ctc_beam_search::BeamComparer<CTCBeamState>>
+class CTCBeamSearchDecoder : public CTCDecoder {
+ // Beam Search
+ //
+ // Example (GravesTh Fig. 7.5):
+ // a -
+ // P = [ 0.3 0.7 ] t = 0
+ // [ 0.4 0.6 ] t = 1
+ //
+ // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42
+ // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58
+ //
+ // In this case, Best Path decoding is suboptimal.
+ //
+ // For Beam Search, we use the following main recurrence relations:
+ //
+ // Relation 1:
+ // ---------------------------------------------------------- Eq. 1
+ // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7)
+ // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7))
+ // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and
+ // updated recursively in the beam entry.
+ //
+ // Relation 2:
+ // ---------------------------------------------------------- Eq. 2
+ // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3)
+ // for ? in a, b, d, ..., (not including c or the blank index),
+ // and the recurrence starts from the beam entry for P(l=abc @ t=2).
+ //
+ // For this case, the length of the new sequence equals t+1 (t
+ // starts at 0). This special case can be calculated as:
+ // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
+ // but we calculate it recursively for speed purposes.
+ typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
+ typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
+ typedef ctc_beam_search::BeamProbability BeamProbability;
+
+ public:
+ typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
+
+ // The beam search decoder is constructed specifying the beam_width (number of
+ // candidates to keep at each decoding timestep) and a beam scorer (used for
+ // custom scoring, for example enabling the use of a language model).
+ // The ownership of the scorer remains with the caller. The default
+ // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
+ // standard beam search.
+ CTCBeamSearchDecoder(int num_classes, int beam_width,
+ BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
+ bool merge_repeated = false)
+ : CTCDecoder(num_classes, batch_size, merge_repeated),
+ beam_width_(beam_width),
+ leaves_(beam_width),
+ beam_scorer_(scorer) {
+ Reset();
+ }
+
+ ~CTCBeamSearchDecoder() override {}
+
+ // Run the hibernating beam search algorithm on the given input.
+ bool Decode(const CTCDecoder::SequenceLength& seq_len,
+ const std::vector<CTCDecoder::Input>& input,
+ std::vector<CTCDecoder::Output>* output,
+ CTCDecoder::ScoreOutput* scores) override;
+
+ // Calculate the next step of the beam search and update the internal state.
+ template <typename Vector>
+ void Step(const Vector& log_input_t);
+
+ template <typename Vector>
+ float GetTopK(const int K, const Vector& input,
+ std::vector<float>* top_k_logits,
+ std::vector<int>* top_k_indices);
+
+ // Retrieve the beam scorer instance used during decoding.
+ BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
+
+ // Set label selection parameters for faster decoding.
+ // See comments for label_selection_size_ and label_selection_margin_.
+ void SetLabelSelectionParameters(int label_selection_size,
+ float label_selection_margin) {
+ label_selection_size_ = label_selection_size;
+ label_selection_margin_ = label_selection_margin;
+ }
+
+ // Reset the beam search
+ void Reset();
+
+ // Extract the top n paths at current time step
+ bool TopPaths(int n, std::vector<std::vector<int>>* paths,
+ std::vector<float>* log_probs, bool merge_repeated) const;
+
+ private:
+ int beam_width_;
+
+ // Label selection is designed to avoid possibly very expensive scorer calls,
+ // by pruning the hypotheses based on the input alone.
+ // Label selection size controls how many items in each beam are passed
+ // through to the beam scorer. Only items with top N input scores are
+ // considered.
+ // Label selection margin controls the difference between minimal input score
+ // (versus the best scoring label) for an item to be passed to the beam
+ // scorer. This margin is expressed in terms of log-probability.
+ // Default is to do no label selection.
+ // For more detail: https://research.google.com/pubs/pub44823.html
+ int label_selection_size_ = 0; // zero means unlimited
+ float label_selection_margin_ = -1; // -1 means unlimited.
+
+ gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
+ std::unique_ptr<BeamRoot> beam_root_;
+ BaseBeamScorer<CTCBeamState>* beam_scorer_;
+
+ CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete;
+ void operator=(const CTCBeamSearchDecoder&) = delete;
+};
+
+template <typename CTCBeamState, typename CTCBeamComparer>
+bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
+ const CTCDecoder::SequenceLength& seq_len,
+ const std::vector<CTCDecoder::Input>& input,
+ std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
+ // Storage for top paths.
+ std::vector<std::vector<int>> beams;
+ std::vector<float> beam_log_probabilities;
+ int top_n = output->size();
+ if (std::any_of(output->begin(), output->end(),
+ [this](const CTCDecoder::Output& output) -> bool {
+ return output.size() < this->batch_size_;
+ })) {
+ return false;
+ }
+ if (scores->rows() < batch_size_ || scores->cols() < top_n) {
+ return false;
+ }
+
+ for (int b = 0; b < batch_size_; ++b) {
+ int seq_len_b = seq_len[b];
+ Reset();
+
+ for (int t = 0; t < seq_len_b; ++t) {
+ // Pass log-probabilities for this example + time.
+ Step(input[t].row(b));
+ } // for (int t...
+
+ // O(n * log(n))
+ std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
+ leaves_.Reset();
+ for (int i = 0; i < branches->size(); ++i) {
+ BeamEntry* entry = (*branches)[i];
+ beam_scorer_->ExpandStateEnd(&entry->state);
+ entry->newp.total +=
+ beam_scorer_->GetStateEndExpansionScore(entry->state);
+ leaves_.push(entry);
+ }
+
+ bool status =
+ TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
+ if (!status) {
+ return status;
+ }
+
+ TFLITE_DCHECK_EQ(top_n, beam_log_probabilities.size());
+ TFLITE_DCHECK_EQ(beams.size(), beam_log_probabilities.size());
+
+ for (int i = 0; i < top_n; ++i) {
+ // Copy output to the correct beam + batch
+ (*output)[i][b].swap(beams[i]);
+ (*scores)(b, i) = -beam_log_probabilities[i];
+ }
+ } // for (int b...
+ return true;
+}
+
+template <typename CTCBeamState, typename CTCBeamComparer>
+template <typename Vector>
+float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
+ const int K, const Vector& input, std::vector<float>* top_k_logits,
+ std::vector<int>* top_k_indices) {
+ // Find the top K choices; worst-case complexity is O(n * k). The input
+ // array is read just once.
+ TFLITE_DCHECK_EQ(num_classes_, input.size());
+ top_k_logits->clear();
+ top_k_indices->clear();
+ top_k_logits->resize(K, -INFINITY);
+ top_k_indices->resize(K, -1);
+ for (int j = 0; j < num_classes_ - 1; ++j) {
+ const float logit = input(j);
+ if (logit > (*top_k_logits)[K - 1]) {
+ int k = K - 1;
+ while (k > 0 && logit > (*top_k_logits)[k - 1]) {
+ (*top_k_logits)[k] = (*top_k_logits)[k - 1];
+ (*top_k_indices)[k] = (*top_k_indices)[k - 1];
+ k--;
+ }
+ (*top_k_logits)[k] = logit;
+ (*top_k_indices)[k] = j;
+ }
+ }
+ // Return the max value: either the largest non-blank logit (stored at
+ // index 0) or the blank-character logit, whichever is greater.
+ return std::max((*top_k_logits)[0], input(num_classes_ - 1));
+}
+
+template <typename CTCBeamState, typename CTCBeamComparer>
+template <typename Vector>
+void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
+ const Vector& raw_input) {
+ std::vector<float> top_k_logits;
+ std::vector<int> top_k_indices;
+ const bool top_k =
+ (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
+ // Number of character classes to consider in each step.
+ const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
+ // Get max coefficient and remove it from raw_input later.
+ float max_coeff;
+ if (top_k) {
+ max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
+ &top_k_indices);
+ } else {
+ max_coeff = raw_input.maxCoeff();
+ }
+ const float label_selection_input_min =
+ (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
+ : -std::numeric_limits<float>::infinity();
+
+ // Extract the beams sorted in decreasing new probability
+ TFLITE_DCHECK_EQ(num_classes_, raw_input.size());
+
+ std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
+ leaves_.Reset();
+
+ for (BeamEntry* b : *branches) {
+ // P(.. @ t) becomes the new P(.. @ t-1)
+ b->oldp = b->newp;
+ }
+
+ for (BeamEntry* b : *branches) {
+ if (b->parent != nullptr) { // if not the root
+ if (b->parent->Active()) {
+ // If last two sequence characters are identical:
+ // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5)
+ // + Pblank(l=ac @ t=5))
+ // else:
+ // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
+ // + P(l=ab @ t=5))
+ float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
+ : b->parent->oldp.total;
+ b->newp.label =
+ LogSumExp(b->newp.label,
+ beam_scorer_->GetStateExpansionScore(b->state, previous));
+ }
+ // Plabel(l=abc @ t=6) *= P(c @ 6)
+ b->newp.label += raw_input(b->label) - max_coeff;
+ }
+ // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
+ b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
+
+ // Push the entry back to the top paths list.
+ // Note, this will always fill leaves back up in sorted order.
+ leaves_.push(b);
+ }
+
+ // branches must be in descending oldp order before growing new leaves.
+ // It already is: it was originally in descending newp order and newp was
+ // copied to oldp above, so no re-sort is needed.
+
+ // Grow new leaves
+ for (BeamEntry* b : *branches) {
+ // A new leaf (represented by its BeamProbability) is a candidate
+ // iff its total probability is nonzero and either the beam list
+ // isn't full, or the lowest probability entry in the beam has a
+ // lower probability than the leaf.
+ auto is_candidate = [this](const BeamProbability& prob) {
+ return (prob.total > kLogZero &&
+ (leaves_.size() < beam_width_ ||
+ prob.total > leaves_.peek_bottom()->newp.total));
+ };
+
+ if (!is_candidate(b->oldp)) {
+ continue;
+ }
+
+ for (int ind = 0; ind < max_classes; ind++) {
+ const int label = top_k ? top_k_indices[ind] : ind;
+ const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
+ // Perform label selection: if input for this label looks very
+ // unpromising, never evaluate it with a scorer.
+ if (logit < label_selection_input_min) {
+ continue;
+ }
+ BeamEntry& c = b->GetChild(label);
+ if (!c.Active()) {
+ // Pblank(l=abcd @ t=6) = 0
+ c.newp.blank = kLogZero;
+ // If new child label is identical to beam label:
+ // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
+ // Otherwise:
+ // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
+ beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
+ float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
+ c.newp.label = logit - max_coeff +
+ beam_scorer_->GetStateExpansionScore(c.state, previous);
+ // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
+ c.newp.total = c.newp.label;
+
+ if (is_candidate(c.newp)) {
+ // Before adding the new node to the beam, check if the beam
+ // is already at maximum width.
+ if (leaves_.size() == beam_width_) {
+ // Bottom is no longer in the beam search. Reset
+ // its probability; signal it's no longer in the beam search.
+ BeamEntry* bottom = leaves_.peek_bottom();
+ bottom->newp.Reset();
+ }
+ leaves_.push(&c);
+ } else {
+ // Deactivate child.
+ c.oldp.Reset();
+ c.newp.Reset();
+ }
+ }
+ }
+ } // for (BeamEntry* b...
+}
+
+template <typename CTCBeamState, typename CTCBeamComparer>
+void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
+ leaves_.Reset();
+
+ // This beam root, and all of its children, will be in memory until
+ // the next reset.
+ beam_root_.reset(new BeamRoot(nullptr, -1));
+ beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
+ beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
+
+ // Add the root as the initial leaf.
+ leaves_.push(beam_root_->RootEntry());
+
+ // Call initialize state on the root object.
+ beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
+}
+
+template <typename CTCBeamState, typename CTCBeamComparer>
+bool CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
+ int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
+ bool merge_repeated) const {
+ TFLITE_DCHECK(paths);
+ TFLITE_DCHECK(log_probs);
+ paths->clear();
+ log_probs->clear();
+ if (n > beam_width_) {
+ return false;
+ }
+ if (n > leaves_.size()) {
+ return false;
+ }
+
+ gtl::TopN<BeamEntry*, CTCBeamComparer> top_branches(n);
+
+ // O(beam_width_ * log(n)), space complexity is O(n)
+ for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) {
+ top_branches.push(*it);
+ }
+ // O(n * log(n))
+ std::unique_ptr<std::vector<BeamEntry*>> branches(top_branches.Extract());
+
+ for (int i = 0; i < n; ++i) {
+ BeamEntry* e((*branches)[i]);
+ paths->push_back(e->LabelSeq(merge_repeated));
+ log_probs->push_back(e->newp.total);
+ }
+ return true;
+}
+
+} // namespace ctc
+} // namespace experimental
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_
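
Editor's note: the class comment above works through the "GravesTh Fig. 7.5" example, where best-path decoding picks the blank sequence (probability 0.42) even though the labeling "a" is more probable (0.58). The arithmetic is reproduced below as a minimal standalone sketch; probabilities are kept in linear space purely for clarity, whereas the decoder itself works with log-probabilities.

#include <cassert>
#include <cmath>

void CheckFig75Example() {
  // Softmax outputs for the two timesteps: (a, blank).
  const float p_a[2] = {0.3f, 0.4f};
  const float p_blank[2] = {0.7f, 0.6f};

  // P(l = "") = P(--) = 0.7 * 0.6 = 0.42: the single best path.
  const float p_empty = p_blank[0] * p_blank[1];

  // P(l = "a") = P(a-) + P(aa) + P(-a) = 0.18 + 0.12 + 0.28 = 0.58.
  const float p_single_a =
      p_a[0] * p_blank[1] + p_a[0] * p_a[1] + p_blank[0] * p_a[1];

  // The most probable labeling ("a") beats the best single path (blank),
  // which is why beam search over labelings is preferred here.
  assert(std::fabs(p_empty - 0.42f) < 1e-5f);
  assert(std::fabs(p_single_a - 0.58f) < 1e-5f);
  assert(p_single_a > p_empty);
}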
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
new file mode 100644
index 0000000000..834d1ebd66
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -0,0 +1,247 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace experimental {
+namespace ctc_beam_search_decoder {
+
+constexpr int kInputsTensor = 0;
+constexpr int kSequenceLengthTensor = 1;
+
+typedef struct {
+ int beam_width;
+ int top_paths;
+ bool merge_repeated;
+} CTCBeamSearchDecoderParams;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ TFLITE_CHECK(buffer != nullptr);
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+ CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams;
+ option->beam_width = m["beam_width"].AsInt32();
+ option->top_paths = m["top_paths"].AsInt32();
+ option->merge_repeated = m["merge_repeated"].AsBool();
+
+ return option;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<CTCBeamSearchDecoderParams*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const CTCBeamSearchDecoderParams* option =
+ reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
+ const int top_paths = option->top_paths;
+ TF_LITE_ENSURE(context, option->beam_width >= top_paths);
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ // The number of outputs should be 3 * top_paths + 1.
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1);
+
+ const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3);
+ // TensorFlow only supports float.
+ TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32);
+ const int batch_size = SizeOfDimension(inputs, 1);
+
+ const TfLiteTensor* sequence_length =
+ GetInput(context, node, kSequenceLengthTensor);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1);
+ TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size);
+ // TensorFlow only supports int32.
+ TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32);
+
+ // Resize decoded outputs.
+ // Do not resize indices & values because we don't know the values yet.
+ for (int i = 0; i < top_paths; ++i) {
+ TfLiteTensor* indices = GetOutput(context, node, i);
+ SetTensorToDynamic(indices);
+ TfLiteTensor* values = GetOutput(context, node, i + top_paths);
+ SetTensorToDynamic(values);
+ TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths);
+ SetTensorToDynamic(output_shape);
+ }
+
+ // Resize log probability outputs.
+ TfLiteTensor* log_probability_output =
+ GetOutput(context, node, top_paths * 3);
+ TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2);
+ log_probability_output_shape_array->data[0] = batch_size;
+ log_probability_output_shape_array->data[1] = top_paths;
+ return context->ResizeTensor(context, log_probability_output,
+ log_probability_output_shape_array);
+}
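
Editor's note: for reference while reading Prepare() above and StoreAllDecodedSequences()/Eval() below, the node's output layout can be summarized with a small helper. The struct and function names here are purely illustrative; the index arithmetic is exactly what the surrounding code uses.

// Illustrative only: output slots for decoded path `p`, given `top_paths`.
struct DecodedPathOutputs {
  int indices;        // [num_entries, 2] sparse (batch, time) indices.
  int values;         // [num_entries] decoded labels.
  int decoded_shape;  // [2] dense shape: (batch_size, max_decoded).
};

inline DecodedPathOutputs OutputsForPath(int p, int top_paths) {
  return {p, p + top_paths, p + 2 * top_paths};
}
// Output index 3 * top_paths holds the [batch_size, top_paths] log
// probabilities.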
+
+TfLiteStatus Resize(TfLiteContext* context,
+ std::initializer_list<int32_t> output_shape,
+ TfLiteTensor* output) {
+ const int dimensions = output_shape.size();
+ TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions);
+ int i = 0;
+ for (const int v : output_shape) {
+ output_shape_array->data[i++] = v;
+ }
+ return context->ResizeTensor(context, output, output_shape_array);
+}
+
+TfLiteStatus StoreAllDecodedSequences(
+ TfLiteContext* context,
+ const std::vector<std::vector<std::vector<int>>>& sequences,
+ TfLiteNode* node, int top_paths) {
+ const int32_t batch_size = sequences.size();
+ std::vector<int32_t> num_entries(top_paths, 0);
+
+ // Calculate num_entries per path
+ for (const auto& batch_s : sequences) {
+ TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths);
+ for (int p = 0; p < top_paths; ++p) {
+ num_entries[p] += batch_s[p].size();
+ }
+ }
+
+ for (int p = 0; p < top_paths; ++p) {
+ const int32_t p_num = num_entries[p];
+
+ // Resize the decoded outputs.
+ TfLiteTensor* indices = GetOutput(context, node, p);
+ TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices));
+
+ TfLiteTensor* values = GetOutput(context, node, p + top_paths);
+ TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values));
+
+ TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths);
+ TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape));
+
+ int32_t max_decoded = 0;
+ int32_t offset = 0;
+
+ int32_t* indices_data = GetTensorData<int32_t>(indices);
+ int32_t* values_data = GetTensorData<int32_t>(values);
+ int32_t* decoded_shape_data = GetTensorData<int32_t>(decoded_shape);
+ for (int b = 0; b < batch_size; ++b) {
+ auto& p_batch = sequences[b][p];
+ int32_t num_decoded = p_batch.size();
+ max_decoded = std::max(max_decoded, num_decoded);
+
+ std::copy_n(p_batch.begin(), num_decoded, values_data + offset);
+ for (int32_t t = 0; t < num_decoded; ++t, ++offset) {
+ indices_data[offset * 2] = b;
+ indices_data[offset * 2 + 1] = t;
+ }
+ }
+
+ decoded_shape_data[0] = batch_size;
+ decoded_shape_data[1] = max_decoded;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor);
+ const TfLiteTensor* sequence_length =
+ GetInput(context, node, kSequenceLengthTensor);
+ const CTCBeamSearchDecoderParams* option =
+ reinterpret_cast<CTCBeamSearchDecoderParams*>(node->user_data);
+
+ const int max_time = SizeOfDimension(inputs, 0);
+ const int batch_size = SizeOfDimension(inputs, 1);
+ const int num_classes = SizeOfDimension(inputs, 2);
+
+ const int beam_width = option->beam_width;
+ const int top_paths = option->top_paths;
+ const bool merge_repeated = option->merge_repeated;
+
+  // Validate that each sequence length is less than or equal to max_time.
+ for (int i = 0; i < batch_size; ++i) {
+ TF_LITE_ENSURE(context,
+ max_time >= GetTensorData<int32_t>(sequence_length)[i]);
+ }
+
+  // The following logic mirrors the implementation in
+ // tensorflow/core/kernels/ctc_decoder_ops.cc
+ std::vector<optimized_ops::TTypes<float>::UnalignedConstMatrix> input_list_t;
+
+ for (std::size_t t = 0; t < max_time; ++t) {
+ input_list_t.emplace_back(
+ GetTensorData<float>(inputs) + t * batch_size * num_classes, batch_size,
+ num_classes);
+ }
+
+ ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer
+ beam_scorer;
+ ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search(
+ num_classes, beam_width, &beam_scorer, 1 /* batch_size */,
+ merge_repeated);
+
+ // Allocate temporary memory for holding chip operation data.
+ float* input_chip_t_data =
+ static_cast<float*>(malloc(num_classes * sizeof(float)));
+ Eigen::array<Eigen::DenseIndex, 1> dims;
+ dims[0] = num_classes;
+ optimized_ops::TTypes<float>::Flat input_chip_t(input_chip_t_data, dims);
+
+ std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
+ std::vector<float> log_probs;
+
+ TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths);
+ float* log_probabilities_output = GetTensorData<float>(log_probabilities);
+
+ // Assumption: the blank index is num_classes - 1
+ for (int b = 0; b < batch_size; ++b) {
+ auto& best_paths_b = best_paths[b];
+ best_paths_b.resize(top_paths);
+ for (int t = 0; t < GetTensorData<int32_t>(sequence_length)[b]; ++t) {
+ input_chip_t = input_list_t[t].chip(b, 0);
+ auto input_bi =
+ Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
+ beam_search.Step(input_bi);
+ }
+ TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b,
+ &log_probs, merge_repeated));
+ beam_search.Reset();
+
+ // Fill in log_probabilities output.
+ for (int bp = 0; bp < top_paths; ++bp) {
+ log_probabilities_output[b * top_paths + bp] = log_probs[bp];
+ }
+ }
+
+ free(input_chip_t_data);
+ return StoreAllDecodedSequences(context, best_paths, node, top_paths);
+}
+
+} // namespace ctc_beam_search_decoder
+
+TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() {
+ static TfLiteRegistration r = {
+ ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free,
+ ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval};
+ return &r;
+}
+
+} // namespace experimental
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
new file mode 100644
index 0000000000..9d1e6a562f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -0,0 +1,238 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace experimental {
+
+TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER();
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class CTCBeamSearchDecoderOpModel : public SingleOpModel {
+ public:
+ CTCBeamSearchDecoderOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> sequence_length_shape,
+ int beam_width, int top_paths,
+ bool merge_repeated) {
+ inputs_ = AddInput(TensorType_FLOAT32);
+ sequence_length_ = AddInput(TensorType_INT32);
+
+ for (int i = 0; i < top_paths * 3; ++i) {
+ outputs_.push_back(AddOutput(TensorType_INT32));
+ }
+ outputs_.push_back(AddOutput(TensorType_FLOAT32));
+
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+ fbb.Int("beam_width", beam_width);
+ fbb.Int("top_paths", top_paths);
+ fbb.Bool("merge_repeated", merge_repeated);
+ });
+ fbb.Finish();
+ SetCustomOp("CTCBeamSearchDecoder", fbb.GetBuffer(),
+ Register_CTC_BEAM_SEARCH_DECODER);
+ BuildInterpreter({input_shape, sequence_length_shape});
+ }
+
+ int inputs() { return inputs_; }
+
+ int sequence_length() { return sequence_length_; }
+
+  std::vector<std::vector<int>> GetDecodedOutputs() {
+ std::vector<std::vector<int>> outputs;
+ for (int i = 0; i < outputs_.size() - 1; ++i) {
+ outputs.push_back(ExtractVector<int>(outputs_[i]));
+ }
+ return outputs;
+ }
+
+ std::vector<float> GetLogProbabilitiesOutput() {
+ return ExtractVector<float>(outputs_[outputs_.size() - 1]);
+ }
+
+ std::vector<std::vector<int>> GetOutputShapes() {
+ std::vector<std::vector<int>> output_shapes;
+ for (const int output : outputs_) {
+ output_shapes.push_back(GetTensorShape(output));
+ }
+ return output_shapes;
+ }
+
+ private:
+ int inputs_;
+ int sequence_length_;
+ std::vector<int> outputs_;
+};
+
+TEST(CTCBeamSearchTest, SimpleTest) {
+ CTCBeamSearchDecoderOpModel m({2, 1, 2}, {1}, 1, 1, true);
+ m.PopulateTensor<float>(m.inputs(),
+ {-0.50922557, -1.35512652, -2.55445064, -1.58419356});
+ m.PopulateTensor<int>(m.sequence_length(), {2});
+ m.Invoke();
+
+ // Make sure the output shapes are right.
+ const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 4);
+ EXPECT_THAT(output_shapes[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(1));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+ EXPECT_THAT(output_shapes[3], ElementsAre(1, 1));
+
+ // Check decoded outputs.
+  const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutputs();
+ EXPECT_EQ(decoded_outputs.size(), 3);
+ EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0));
+ EXPECT_THAT(decoded_outputs[1], ElementsAre(0));
+ EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));
+ // Check log probabilities output.
+ EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear({0.32134813})));
+}
+
+TEST(CTCBeamSearchTest, MultiBatchTest) {
+ CTCBeamSearchDecoderOpModel m({3, 3, 3}, {3}, 1, 1, true);
+ m.PopulateTensor<float>(
+ m.inputs(),
+ {-0.63649208, -0.00487571, -0.04249819, -0.67754697, -1.0341399,
+ -2.14717721, -0.77686821, -3.41973774, -0.05151402, -0.21482619,
+ -0.57411168, -1.45039917, -0.73769373, -2.10941739, -0.44818325,
+ -0.25287673, -2.80057302, -0.54748312, -0.73334867, -0.86537719,
+ -0.2065197, -0.18725838, -1.42770405, -0.86051965, -1.61642301,
+ -2.07275114, -0.9201845});
+ m.PopulateTensor<int>(m.sequence_length(), {3, 3, 3});
+ m.Invoke();
+
+ // Make sure the output shapes are right.
+ const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 4);
+ EXPECT_THAT(output_shapes[0], ElementsAre(4, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(4));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+ EXPECT_THAT(output_shapes[3], ElementsAre(3, 1));
+
+ // Check decoded outputs.
+  const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutputs();
+ EXPECT_EQ(decoded_outputs.size(), 3);
+ EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 2, 0));
+ EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
+ EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));
+ // Check log probabilities output.
+ EXPECT_THAT(
+ m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572})));
+}
+
+TEST(CTCBeamSearchTest, MultiPathsTest) {
+ CTCBeamSearchDecoderOpModel m({3, 2, 5}, {2}, 3, 2, true);
+ m.PopulateTensor<float>(
+ m.inputs(),
+ {-2.206851, -0.09542714, -0.2393415, -3.81866197, -0.27241158,
+ -0.20371124, -0.68236623, -1.1397166, -0.17422639, -1.85224048,
+ -0.9406037, -0.32544678, -0.21846784, -0.38377237, -0.33498676,
+ -0.10139782, -0.51886883, -0.21678554, -0.15267063, -1.91164412,
+ -0.31328673, -0.27462716, -0.65975336, -1.53671973, -2.76554225,
+ -0.23920634, -1.2370502, -4.98751576, -3.12995717, -0.43129368});
+ m.PopulateTensor<int>(m.sequence_length(), {3, 3});
+ m.Invoke();
+
+ // Make sure the output shapes are right.
+ const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 7);
+ EXPECT_THAT(output_shapes[0], ElementsAre(4, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3, 2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(4));
+ EXPECT_THAT(output_shapes[3], ElementsAre(3));
+ EXPECT_THAT(output_shapes[4], ElementsAre(2));
+ EXPECT_THAT(output_shapes[5], ElementsAre(2));
+ EXPECT_THAT(output_shapes[6], ElementsAre(2, 2));
+
+ // Check decoded outputs.
+  const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutputs();
+ EXPECT_EQ(decoded_outputs.size(), 6);
+ EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 1, 1));
+ EXPECT_THAT(decoded_outputs[1], ElementsAre(0, 0, 0, 1, 1, 0));
+ EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 2, 3, 0));
+ EXPECT_THAT(decoded_outputs[3], ElementsAre(2, 1, 0));
+ EXPECT_THAT(decoded_outputs[4], ElementsAre(2, 2));
+ EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));
+ // Check log probabilities output.
+ EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.91318405, 0.9060272, 1.0780245, 0.64358956})));
+}
+
+TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
+ CTCBeamSearchDecoderOpModel m({3, 3, 4}, {3}, 3, 1, true);
+ m.PopulateTensor<float>(
+ m.inputs(),
+ {-1.26658163, -0.25760023, -0.03917975, -0.63772235, -0.03794756,
+ -0.45063099, -0.27706473, -0.01569179, -0.59940385, -0.35700127,
+ -0.48920721, -1.42635476, -1.3462478, -0.02565498, -0.30179568,
+ -0.6491698, -0.55017719, -2.92291466, -0.92522973, -0.47592022,
+ -0.07099135, -0.31575624, -0.86345281, -0.36017021, -0.79208612,
+ -1.75306124, -0.65089224, -0.00912786, -0.42915003, -1.72606203,
+ -1.66337589, -0.70800793, -2.52272352, -0.67329562, -2.49145522,
+ -0.49786342});
+ m.PopulateTensor<int>(m.sequence_length(), {1, 2, 3});
+ m.Invoke();
+
+ // Make sure the output shapes are right.
+ const std::vector<std::vector<int>>& output_shapes = m.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 4);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+ EXPECT_THAT(output_shapes[3], ElementsAre(3, 1));
+
+ // Check decoded outputs.
+  const std::vector<std::vector<int>>& decoded_outputs = m.GetDecodedOutputs();
+ EXPECT_EQ(decoded_outputs.size(), 3);
+ EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 1, 0, 2, 0));
+ EXPECT_THAT(decoded_outputs[1], ElementsAre(2, 0, 1));
+ EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));
+ // Check log probabilities output.
+ EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005})));
+}
+
+} // namespace
+} // namespace experimental
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h
new file mode 100644
index 0000000000..596ad4a5f7
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h
@@ -0,0 +1,114 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/util/ctc/ctc_decoder.h
+// TODO(b/111524997): Remove this file.
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_
+
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tflite {
+namespace experimental {
+namespace ctc {
+
+// The CTCDecoder is an abstract interface to be implemented when providing a
+// decoding method on the timestep output of a RNN trained with CTC loss.
+//
+// The two types of decoding available are:
+// - greedy path, through the CTCGreedyDecoder
+// - beam search, through the CTCBeamSearchDecoder
+class CTCDecoder {
+ public:
+ typedef Eigen::Map<const Eigen::ArrayXi> SequenceLength;
+ typedef Eigen::Map<const Eigen::MatrixXf> Input;
+ typedef std::vector<std::vector<int>> Output;
+ typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;
+
+ CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
+ : num_classes_(num_classes),
+ blank_index_(num_classes - 1),
+ batch_size_(batch_size),
+ merge_repeated_(merge_repeated) {}
+
+ virtual ~CTCDecoder() {}
+
+ // Dimensionality of the input/output is expected to be:
+ // - seq_len[b] - b = 0 to batch_size_
+  //  - input[t].rows(b) - t = 0 to timesteps; b = 0 to batch_size_
+ // - output.size() specifies the number of beams to be returned.
+ // - scores(b, i) - b = 0 to batch_size; i = 0 to output.size()
+ virtual bool Decode(const SequenceLength& seq_len,
+ const std::vector<Input>& input,
+ std::vector<Output>* output, ScoreOutput* scores) = 0;
+
+ int batch_size() { return batch_size_; }
+ int num_classes() { return num_classes_; }
+
+ protected:
+ int num_classes_;
+ int blank_index_;
+ int batch_size_;
+ bool merge_repeated_;
+};
+
+// CTCGreedyDecoder is an implementation of the simple best-path decoding
+// algorithm, selecting the most likely class at each timestep.
+class CTCGreedyDecoder : public CTCDecoder {
+ public:
+ CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
+ : CTCDecoder(num_classes, batch_size, merge_repeated) {}
+
+ bool Decode(const CTCDecoder::SequenceLength& seq_len,
+ const std::vector<CTCDecoder::Input>& input,
+ std::vector<CTCDecoder::Output>* output,
+ CTCDecoder::ScoreOutput* scores) override {
+ if (output->empty() || (*output)[0].size() < batch_size_) {
+ return false;
+ }
+ if (scores->rows() < batch_size_ || scores->cols() == 0) {
+ return false;
+ }
+ // For each batch entry, identify the transitions
+ for (int b = 0; b < batch_size_; ++b) {
+ int seq_len_b = seq_len[b];
+ // Only writing to beam 0
+ std::vector<int>& output_b = (*output)[0][b];
+
+ int prev_class_ix = -1;
+ (*scores)(b, 0) = 0;
+ for (int t = 0; t < seq_len_b; ++t) {
+ auto row = input[t].row(b);
+ int max_class_ix;
+ (*scores)(b, 0) += -row.maxCoeff(&max_class_ix);
+ if (max_class_ix != blank_index_ &&
+ !(merge_repeated_ && max_class_ix == prev_class_ix)) {
+ output_b.push_back(max_class_ix);
+ }
+ prev_class_ix = max_class_ix;
+ }
+ }
+ return true;
+ }
+};
+
+} // namespace ctc
+} // namespace experimental
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h
new file mode 100644
index 0000000000..0bae732533
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied from tensorflow/core/util/ctc/ctc_loss_util.h
+// TODO(b/111524997): Remove this file.
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_
+
+#include <cmath>
+#include <limits>
+
+namespace tflite {
+namespace experimental {
+namespace ctc {
+
+const float kLogZero = -std::numeric_limits<float>::infinity();
+
+// Add logarithmic probabilities using:
+// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
+// The two inputs are assumed to be log probabilities.
+// (GravesTh) Eq. 7.18
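+// For example, LogSumExp(logf(0.5f), logf(0.25f)) is approximately logf(0.75f).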
+inline float LogSumExp(float log_prob_1, float log_prob_2) {
+  // Always have 'b' be the smaller number to keep the exponential from
+ // blowing up.
+ if (log_prob_1 == kLogZero && log_prob_2 == kLogZero) {
+ return kLogZero;
+ } else {
+ return (log_prob_1 > log_prob_2)
+ ? log_prob_1 + log1pf(expf(log_prob_2 - log_prob_1))
+ : log_prob_2 + log1pf(expf(log_prob_1 - log_prob_2));
+ }
+}
+
+} // namespace ctc
+} // namespace experimental
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_
diff --git a/tensorflow/contrib/lite/experimental/kernels/top_n.h b/tensorflow/contrib/lite/experimental/kernels/top_n.h
new file mode 100644
index 0000000000..cd2a2f1c80
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/kernels/top_n.h
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This simple class finds the top n elements of an incrementally provided set
+// of elements which you push one at a time. If the number of elements exceeds
+// n, the lowest elements are incrementally dropped. At the end you get
+// a vector of the top elements sorted in descending order (through Extract() or
+// ExtractNondestructive()), or a vector of the top elements but not sorted
+// (through ExtractUnsorted() or ExtractUnsortedNondestructive()).
+//
+// The value n is specified in the constructor. If there are p elements pushed
+// altogether:
+// The total storage requirements are O(min(n, p)) elements
+// The running time is O(p * log(min(n, p))) comparisons
+// If n is a constant, the total storage required is a constant and the running
+// time is linear in p.
+//
+// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p)
+// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements,
+// discarding the lowest n elements whenever the buffer is full using a linear-
+// time median algorithm. This may have better performance when the input
+// sequence is partially sorted.
+//
+// NOTE(zhifengc): This class should be redesigned to avoid reallocating a
+// vector for each Extract.
+
+// Copied from tensorflow/core/lib/gtl/top_n.h
+// TODO(b/111524997): Remove this file.
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+
+namespace tflite {
+namespace gtl {
+
+// Cmp is an STL binary predicate. Note that Cmp is the "greater" predicate,
+// not the more commonly used "less" predicate.
+//
+// If you use a "less" predicate here, the TopN will pick out the bottom N
+// elements out of the ones passed to it, and it will return them sorted in
+// ascending order.
+//
+// TopN is rule-of-zero copyable and movable if its members are.
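+//
+// Example with the default std::greater<T> comparator:
+//   TopN<int> top3(3);
+//   for (int v : {5, 1, 9, 4, 7}) top3.push(v);
+//   std::unique_ptr<std::vector<int>> out(top3.Extract());  // contains {9, 7, 5}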
+template <class T, class Cmp = std::greater<T> >
+class TopN {
+ public:
+ // The TopN is in one of the three states:
+ //
+ // o UNORDERED: this is the state an instance is originally in,
+ // where the elements are completely orderless.
+ //
+ // o BOTTOM_KNOWN: in this state, we keep the invariant that there
+ // is at least one element in it, and the lowest element is at
+ // position 0. The elements in other positions remain
+ // unsorted. This state is reached if the state was originally
+ // UNORDERED and a peek_bottom() function call is invoked.
+ //
+ // o HEAP_SORTED: in this state, the array is kept as a heap and
+ // there are exactly (limit_+1) elements in the array. This
+ // state is reached when at least (limit_+1) elements are
+ // pushed in.
+ //
+  // The state transition graph is as follows:
+ //
+ // peek_bottom() (limit_+1) elements
+ // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED
+ // | ^
+ // | (limit_+1) elements |
+ // +-----------------------------------------------------------+
+
+ enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED };
+ using UnsortedIterator = typename std::vector<T>::const_iterator;
+
+ // 'limit' is the maximum number of top results to return.
+ explicit TopN(size_t limit) : TopN(limit, Cmp()) {}
+ TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {}
+
+ size_t limit() const { return limit_; }
+
+ // Number of elements currently held by this TopN object. This
+ // will be no greater than 'limit' passed to the constructor.
+ size_t size() const { return std::min(elements_.size(), limit_); }
+
+ bool empty() const { return size() == 0; }
+
+ // If you know how many elements you will push at the time you create the
+ // TopN object, you can call reserve to preallocate the memory that TopN
+ // will need to process all 'n' pushes. Calling this method is optional.
+ void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); }
+
+ // Push 'v'. If the maximum number of elements was exceeded, drop the
+ // lowest element and return it in 'dropped' (if given). If the maximum is not
+ // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or
+ // nullptr, in which case it is not filled in.
+ // Requires: T is CopyAssignable, Swappable
+ void push(const T &v) { push(v, nullptr); }
+ void push(const T &v, T *dropped) { PushInternal(v, dropped); }
+
+ // Move overloads of push.
+ // Requires: T is MoveAssignable, Swappable
+ void push(T &&v) { // NOLINT(build/c++11)
+ push(std::move(v), nullptr);
+ }
+ void push(T &&v, T *dropped) { // NOLINT(build/c++11)
+ PushInternal(std::move(v), dropped);
+ }
+
+ // Peeks the bottom result without calling Extract()
+ const T &peek_bottom();
+
+ // Extract the elements as a vector sorted in descending order. The caller
+ // assumes ownership of the vector and must delete it when done. This is a
+ // destructive operation. The only method that can be called immediately
+ // after Extract() is Reset().
+ std::vector<T> *Extract();
+
+ // Similar to Extract(), but makes no guarantees the elements are in sorted
+ // order. As with Extract(), the caller assumes ownership of the vector and
+ // must delete it when done. This is a destructive operation. The only
+ // method that can be called immediately after ExtractUnsorted() is Reset().
+ std::vector<T> *ExtractUnsorted();
+
+ // A non-destructive version of Extract(). Copy the elements in a new vector
+ // sorted in descending order and return it. The caller assumes ownership of
+ // the new vector and must delete it when done. After calling
+ // ExtractNondestructive(), the caller can continue to push() new elements.
+ std::vector<T> *ExtractNondestructive() const;
+
+ // A non-destructive version of Extract(). Copy the elements to a given
+ // vector sorted in descending order. After calling
+ // ExtractNondestructive(), the caller can continue to push() new elements.
+ // Note:
+  //  1. The given argument must be allocated.
+ // 2. Any data contained in the vector prior to the call will be deleted
+ // from it. After the call the vector will contain only the elements
+ // from the data structure.
+ void ExtractNondestructive(std::vector<T> *output) const;
+
+ // A non-destructive version of ExtractUnsorted(). Copy the elements in a new
+ // vector and return it, with no guarantees the elements are in sorted order.
+ // The caller assumes ownership of the new vector and must delete it when
+ // done. After calling ExtractUnsortedNondestructive(), the caller can
+ // continue to push() new elements.
+ std::vector<T> *ExtractUnsortedNondestructive() const;
+
+ // A non-destructive version of ExtractUnsorted(). Copy the elements into
+ // a given vector, with no guarantees the elements are in sorted order.
+ // After calling ExtractUnsortedNondestructive(), the caller can continue
+ // to push() new elements.
+ // Note:
+  //  1. The given argument must be allocated.
+ // 2. Any data contained in the vector prior to the call will be deleted
+ // from it. After the call the vector will contain only the elements
+ // from the data structure.
+ void ExtractUnsortedNondestructive(std::vector<T> *output) const;
+
+ // Return an iterator to the beginning (end) of the container,
+ // with no guarantees about the order of iteration. These iterators are
+ // invalidated by mutation of the data structure.
+ UnsortedIterator unsorted_begin() const { return elements_.begin(); }
+ UnsortedIterator unsorted_end() const { return elements_.begin() + size(); }
+
+ // Accessor for comparator template argument.
+ Cmp *comparator() { return &cmp_; }
+
+ // This removes all elements. If Extract() or ExtractUnsorted() have been
+  // called, this will put it back in an empty but usable state.
+ void Reset();
+
+ private:
+ template <typename U>
+ void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11)
+
+ // elements_ can be in one of two states:
+ // elements_.size() <= limit_: elements_ is an unsorted vector of elements
+ // pushed so far.
+ // elements_.size() > limit_: The last element of elements_ is unused;
+  //  the other elements of elements_ are an STL heap whose size is exactly
+ // limit_. In this case elements_.size() is exactly one greater than
+ // limit_, but don't use "elements_.size() == limit_ + 1" to check for
+ // that because you'll get a false positive if limit_ == size_t(-1).
+ std::vector<T> elements_;
+ size_t limit_; // Maximum number of elements to find
+ Cmp cmp_; // Greater-than comparison function
+ State state_ = UNORDERED;
+};
+
+// ----------------------------------------------------------------------
+// Implementations of non-inline functions
+
+template <class T, class Cmp>
+template <typename U>
+void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11)
+ if (limit_ == 0) {
+ if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11)
+ return;
+ }
+ if (state_ != HEAP_SORTED) {
+ elements_.push_back(std::forward<U>(v)); // NOLINT(build/c++11)
+ if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) {
+ // Easy case: we just pushed the new element back
+ } else {
+ // To maintain the BOTTOM_KNOWN state, we need to make sure that
+ // the element at position 0 is always the smallest. So we put
+ // the new element at position 0 and push the original bottom
+ // element in the back.
+ // Warning: this code is subtle.
+ using std::swap;
+ swap(elements_.front(), elements_.back());
+ }
+ if (elements_.size() == limit_ + 1) {
+ // Transition from unsorted vector to a heap.
+ std::make_heap(elements_.begin(), elements_.end(), cmp_);
+ if (dropped) *dropped = std::move(elements_.front());
+ std::pop_heap(elements_.begin(), elements_.end(), cmp_);
+ state_ = HEAP_SORTED;
+ }
+ } else {
+ // Only insert the new element if it is greater than the least element.
+ if (cmp_(v, elements_.front())) {
+ elements_.back() = std::forward<U>(v); // NOLINT(build/c++11)
+ std::push_heap(elements_.begin(), elements_.end(), cmp_);
+ if (dropped) *dropped = std::move(elements_.front());
+ std::pop_heap(elements_.begin(), elements_.end(), cmp_);
+ } else {
+ if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11)
+ }
+ }
+}
+
+template <class T, class Cmp>
+const T &TopN<T, Cmp>::peek_bottom() {
+ TFLITE_DCHECK(!empty());
+ if (state_ == UNORDERED) {
+ // We need to do a linear scan to find out the bottom element
+ int min_candidate = 0;
+ for (size_t i = 1; i < elements_.size(); ++i) {
+ if (cmp_(elements_[min_candidate], elements_[i])) {
+ min_candidate = i;
+ }
+ }
+ // By swapping the element at position 0 and the minimal
+ // element, we transition to the BOTTOM_KNOWN state
+ if (min_candidate != 0) {
+ using std::swap;
+ swap(elements_[0], elements_[min_candidate]);
+ }
+ state_ = BOTTOM_KNOWN;
+ }
+ return elements_.front();
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::Extract() {
+ auto out = new std::vector<T>;
+ out->swap(elements_);
+ if (state_ != HEAP_SORTED) {
+ std::sort(out->begin(), out->end(), cmp_);
+ } else {
+ out->pop_back();
+ std::sort_heap(out->begin(), out->end(), cmp_);
+ }
+ return out;
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() {
+ auto out = new std::vector<T>;
+ out->swap(elements_);
+ if (state_ == HEAP_SORTED) {
+ // Remove the limit_+1'th element.
+ out->pop_back();
+ }
+ return out;
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const {
+ auto out = new std::vector<T>;
+ ExtractNondestructive(out);
+ return out;
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const {
+ TFLITE_DCHECK(output);
+ *output = elements_;
+ if (state_ != HEAP_SORTED) {
+ std::sort(output->begin(), output->end(), cmp_);
+ } else {
+ output->pop_back();
+ std::sort_heap(output->begin(), output->end(), cmp_);
+ }
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const {
+ auto elements = new std::vector<T>;
+ ExtractUnsortedNondestructive(elements);
+ return elements;
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const {
+ TFLITE_DCHECK(output);
+ *output = elements_;
+ if (state_ == HEAP_SORTED) {
+ // Remove the limit_+1'th element.
+ output->pop_back();
+ }
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::Reset() {
+ elements_.clear();
+ state_ = UNORDERED;
+}
+
+} // namespace gtl
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index e38597495d..7a680f5c64 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -26,18 +26,12 @@ limitations under the License.
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
-#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
-#endif
#include "tensorflow/contrib/lite/profiling/profiler.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
-#ifdef TFLITE_MCU
-class NNAPIDelegate {};
-#endif
-
namespace {
TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node,
@@ -630,7 +624,6 @@ TfLiteStatus Interpreter::Invoke() {
}
TfLiteStatus status = kTfLiteOk;
-#ifndef TFLITE_MCU
if (nnapi_delegate_) {
if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
@@ -644,7 +637,6 @@ TfLiteStatus Interpreter::Invoke() {
return kTfLiteError;
}
}
-#endif
// Invocations are always done in node order.
// Note that calling Invoke repeatedly will cause the original memory plan to
@@ -902,17 +894,15 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
}
void Interpreter::UseNNAPI(bool enable) {
-#ifndef TFLITE_MCU
// TODO(aselle): This is a workaround for finding if NNAPI exists.
// We also need to make sure getLibraryHandle() is renamed to be NNAPI
// prefixed.
- if (!NNAPIExists()) enable = false;
+ if (!NNAPIDelegate::IsSupported()) enable = false;
if (!enable) {
nnapi_delegate_.reset();
} else if (!nnapi_delegate_) {
nnapi_delegate_.reset(new NNAPIDelegate);
}
-#endif
}
void Interpreter::SetNumThreads(int num_threads) {
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 329c98f91e..c5586475ec 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -8,6 +8,19 @@ load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+# Suppress warnings that are introduced by Eigen Tensor.
+EXTRA_EIGEN_COPTS = select({
+ "//tensorflow:ios": [
+ "-Wno-error=invalid-partial-specialization",
+ "-Wno-error=reorder",
+ ],
+ "//tensorflow:windows": [
+ "/DEIGEN_HAS_C99_MATH",
+ "/DEIGEN_AVOID_STL_ARRAY",
+ ],
+ "//conditions:default": ["-Wno-error=reorder"],
+})
+
tf_cc_test(
name = "optional_tensor_test",
size = "small",
@@ -49,13 +62,7 @@ cc_library(
hdrs = [
"eigen_support.h",
],
- copts = tflite_copts() + [
- "-Wno-error=reorder",
- ] + select({
- "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
- "//conditions:default": [
- ],
- }),
+ copts = tflite_copts() + EXTRA_EIGEN_COPTS,
deps = [
":op_macros",
"//tensorflow/contrib/lite:arena_planner",
@@ -209,14 +216,7 @@ cc_library(
"padding.h",
"register.h",
],
- # Suppress warnings that are introduced by Eigen Tensor.
- copts = tflite_copts() + [
- "-Wno-error=reorder",
- ] + select({
- "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
- "//conditions:default": [
- ],
- }),
+ copts = tflite_copts() + EXTRA_EIGEN_COPTS,
deps = [
":activation_functor",
":eigen_support",
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 6e13b8c667..817266a471 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -212,25 +212,25 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
- output->type = input->type;
-
// Currently only Float32 is supported
// TODO(ycling): Support other data types.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32);
+ output->type = input->type;
- // Currently, only support 4D `input` and 3D `alpha` with shape
- // (1, 1, channels).
- // TODO(impjdi): Support other cases where `alpha` is broadcastable
- // to `input`.
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]);
+ // PRelu (parameteric Relu) shares the same alpha value on "shared axis".
+ // This means it's always required to "broadcast" alpha values in PRelu.
+ TfLiteIntArray* output_size = nullptr;
+ TF_LITE_ENSURE_OK(
+ context, CalculateShapeForBroadcast(context, input, alpha, &output_size));
- return context->ResizeTensor(context, output,
- TfLiteIntArrayCopy(input->dims));
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+ // After broadcasting, the output shape should always be the same as the
+ // input shape.
+ TF_LITE_ENSURE(context, HaveSameShapes(input, output));
+
+ return kTfLiteOk;
}
TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
@@ -524,33 +524,24 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
}
+template <typename T>
+T ApplyPrelu(T input, T alpha) {
+ return input >= 0.0 ? input : input * alpha;
+}
+
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
- const TfLiteTensor* output = GetOutput(context, node, 0);
-
+ TfLiteTensor* output = GetOutput(context, node, 0);
if (input->type != kTfLiteFloat32) {
context->ReportError(context, "Only float32 supported currently, got %d.",
input->type);
return kTfLiteError;
}
- TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
- const int batches = input->dims->data[0];
- const int height = input->dims->data[1];
- const int width = input->dims->data[2];
- const int channels = input->dims->data[3];
-
- TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
- TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels);
-
- const int n = batches * height * width * channels;
- for (int i = 0; i < n; ++i) {
- const float x = input->data.f[i];
- output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x;
- }
-
+ reference_ops::BroadcastBinaryFunction<float, float, float>(
+ GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(alpha), GetTensorDims(alpha),
+ GetTensorData<float>(output), GetTensorDims(output), ApplyPrelu<float>);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 6f174763df..04c0263b78 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -256,10 +256,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
- TF_LITE_ENSURE(context, real_multiplier < 1.0);
- QuantizeMultiplierSmallerThanOneExp(
- real_multiplier, &data->output_multiplier, &data->output_shift);
- data->output_shift *= -1;
+
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
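+  // Unlike the previous QuantizeMultiplierSmallerThanOneExp path, the real
+  // multiplier is no longer required to be smaller than 1.0.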
CalculateActivationRangeUint8(params->activation, output,
&data->output_activation_min,
&data->output_activation_max);
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 0dcfc826fd..24633c2fd7 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -64,12 +64,6 @@ class BaseConvolutionOpModel : public SingleOpModel {
}
output_ = AddOutput(output);
- if (input.type != TensorType_FLOAT32) {
- // The following is required by quantized inference. It is the unittest's
- // responsibility to make sure the output scale falls into the correct
- // range.
- CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
- }
SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
CreateConv2DOptions(
@@ -441,6 +435,44 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantized) {
}));
}
+TEST_P(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) {
+ // output_multiplier = 1.0118
+ QuantizedConvolutionOpModel quant_op(
+ GetRegistration(), {TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128},
+ {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128},
+ {TensorType_UINT8, {}, -127, 128});
+ ConvolutionOpModel float_op(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
+ std::initializer_list<float> input = {
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ };
+ std::initializer_list<float> filter = {
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ };
+ std::initializer_list<float> bias = {1, 2, 3};
+
+ quant_op.SetInput(input);
+ quant_op.SetFilter(filter);
+ quant_op.SetBias(bias);
+ quant_op.Invoke();
+
+ float_op.SetInput(input);
+ float_op.SetFilter(filter);
+ float_op.SetBias(bias);
+ float_op.Invoke();
+
+ EXPECT_THAT(quant_op.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
+}
+
TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) {
QuantizedConvolutionOpModel m(GetRegistration(),
{TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64},
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index 0c532cac5a..d7bde0ff79 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -40,8 +40,8 @@ constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;
-constexpr size_t kNumCoordBox = 4;
-constexpr size_t kBatchSize = 1;
+constexpr int kNumCoordBox = 4;
+constexpr int kBatchSize = 1;
// Object Detection model produces axis-aligned boxes in two formats:
// BoxCorner represents the upper right (xmin, ymin) and
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 3a855fe3dd..0d424071da 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -481,6 +481,9 @@ cc_library(
":darwin": [
":neon_tensor_utils",
],
+ ":darwin_x86_64": [
+ ":neon_tensor_utils",
+ ],
"//conditions:default": [
":portable_tensor_utils",
],
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index 310a8980e6..eb4d0108bd 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -117,6 +117,9 @@ template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
"Only unsigned integer types handled.");
+#if defined(__GNUC__)
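+  // __builtin_clz is undefined for a zero argument, so handle zero explicitly.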
+ return integer_input ? __builtin_clz(integer_input) : 0;
+#else
const T one_in_leading_positive = static_cast<T>(1)
<< (std::numeric_limits<T>::digits - 1);
int leading_zeros = 0;
@@ -125,6 +128,7 @@ int CountLeadingZeros(T integer_input) {
++leading_zeros;
}
return leading_zeros;
+#endif
}
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
index d85e06a5d5..250872c422 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h
@@ -33,7 +33,7 @@ limitations under the License.
#include <functional>
#ifdef _WIN32
-#include <winbase.h>
+#include <windows.h>
#elif defined(__APPLE__)
#include <mach/mach_time.h>
#else
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 78567d52ea..6adb879c71 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -168,6 +168,18 @@ ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
return ArrayMap<Scalar>(data, rows, cols);
}
+// Copied from tensorflow/core/framework/tensor_types.h
+template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
+struct TTypes {
+ // Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
+ Eigen::Aligned>
+ Flat;
+ typedef Eigen::TensorMap<
+ Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>>
+ UnalignedConstMatrix;
+};
+
// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar, int N>
@@ -1018,10 +1030,10 @@ inline void FullyConnectedAsGEMV(
struct GemmlowpOutputPipeline {
typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
ColVectorMap;
- typedef std::tuple<
- gemmlowp::OutputStageBiasAddition<ColVectorMap>,
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
- gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8>
+ typedef std::tuple<gemmlowp::OutputStageBiasAddition<ColVectorMap>,
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent,
+ gemmlowp::OutputStageClamp,
+ gemmlowp::OutputStageSaturatingCastToUint8>
Pipeline;
static Pipeline MakeExp(const int32* bias_data, int output_rows,
int32 output_offset, int32 output_multiplier,
@@ -1030,11 +1042,10 @@ struct GemmlowpOutputPipeline {
ColVectorMap bias_vector(bias_data, output_rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
- quantize_down_stage;
+ gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage;
quantize_down_stage.result_offset_after_shift = output_offset;
quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
- quantize_down_stage.result_shift = -output_left_shift;
+ quantize_down_stage.result_exponent = output_left_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
@@ -2315,7 +2326,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
++*output_shift;
}
TFLITE_DCHECK_GT(input, 0);
- const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bits =
+ CountLeadingZeros(static_cast<uint32>(input)) - 1;
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
*output_shift -= left_shift_bit_pairs;
@@ -4023,7 +4035,7 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
// perform a division by the above-computed sum-of-exponentials.
int32 fixed_sum_of_exps = sum_of_exps.raw();
int headroom_plus_one =
- __builtin_clz(static_cast<uint32>(fixed_sum_of_exps));
+ CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
// This is the number of bits to the left of the binary point above 1.0.
// Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
// no later adjustment will be needed.
@@ -4169,7 +4181,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl(
// required shift "ourselves" instead of using, say, Rescale.
FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw());
// z_a_pow_2 = input_integer_bits - z_a_headroom;
- int z_a_headroom_plus_1 = __builtin_clz(static_cast<uint32>(z_a.raw()));
+ int z_a_headroom_plus_1 = CountLeadingZeros(static_cast<uint32>(z_a.raw()));
FixedPoint0 r_a_tmp =
SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1));
const int32 r_a_raw =
@@ -4184,7 +4196,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl(
// z_b is treated like z_a, but premultiplying by sqrt(0.5).
FixedPoint0 z_b = z_a * sqrt_half;
- int z_b_headroom = __builtin_clz(static_cast<uint32>(z_b.raw())) - 1;
+ int z_b_headroom = CountLeadingZeros(static_cast<uint32>(z_b.raw())) - 1;
const int32 r_b_raw =
SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 6bd88b5596..e6ccd7a32c 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -21,6 +21,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#if defined(_MSC_VER)
+#define __restrict__ __restrict
+#endif
+
namespace tflite {
namespace tensor_utils {
@@ -38,10 +42,8 @@ bool PortableIsZeroVector(const float* vector, int v_size) {
}
void PortableSymmetricQuantizeFloats(const float* values, const int size,
- int8_t* quantized_values,
- float* __restrict__ min_value,
- float* __restrict__ max_value,
- float* __restrict__ scaling_factor) {
+ int8_t* quantized_values, float* min_value,
+ float* max_value, float* scaling_factor) {
auto minmax = std::minmax_element(values, values + size);
*min_value = *minmax.first;
*max_value = *minmax.second;
@@ -93,9 +95,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
for (row = 0; row < m_rows; ++row, result += result_stride) {
// Initialize the dot product sum for the row to 0.
int32_t dotprod = 0;
+#if defined(__GNUC__)
// Prefetch the row to cache.
__builtin_prefetch(row_ptr, 0 /* prefetch for read */,
3 /* temporal locality */);
+#endif
// For every block of 16 8-bit elements (128-bit register) from each row.
for (col = 0; col < m_cols; ++col, ++row_ptr) {
dotprod += (*row_ptr) * (vectors[col]);
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 714613b96e..7eb6fe34bc 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -322,8 +322,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
if (bias_data) {
acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
}
- acc = MultiplyByQuantizedMultiplierSmallerThanOneExp(
- acc, output_multiplier, kReverseShift * output_shift);
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ kReverseShift * output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -903,7 +903,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
++*output_shift;
}
TFLITE_DCHECK_GT(input, 0);
- const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
+ const unsigned max_left_shift_bits =
+ CountLeadingZeros(static_cast<uint32>(input)) - 1;
const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
*output_shift -= left_shift_bit_pairs;
@@ -4190,8 +4191,8 @@ inline void RankOneSelect(const D* input_condition_data,
}
// For easy implementation, the indices is always a vector of size-4 vectors.
-template <typename T, typename I>
-inline void SparseToDense(const std::vector<std::vector<I>>& indices,
+template <typename T, typename TI>
+inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
const T* values, T default_value, T* output_data,
const Dims<4>& output_dims, bool value_is_scalar) {
const int value_count = indices.size();
@@ -4206,7 +4207,7 @@ inline void SparseToDense(const std::vector<std::vector<I>>& indices,
// condition within the loop every time.
if (value_is_scalar) {
for (int i = 0; i < value_count; ++i) {
- const std::vector<I>& index = indices[i];
+ const std::vector<TI>& index = indices[i];
TFLITE_DCHECK_EQ(index.size(), 4);
const T value = *values; // just use the first value.
output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
@@ -4217,7 +4218,7 @@ inline void SparseToDense(const std::vector<std::vector<I>>& indices,
// Go through the values and indices to fill the sparse values.
for (int i = 0; i < value_count; ++i) {
- const std::vector<I>& index = indices[i];
+ const std::vector<TI>& index = indices[i];
TFLITE_DCHECK_EQ(index.size(), 4);
const T value = values[i];
output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
@@ -4287,6 +4288,33 @@ inline void BroadcastLogical(const bool* input1_data,
}
}
+// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
+// generalized and efficient BroadcastBinaryFunction.
+//
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
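+//
+// For example, the PRelu kernel in this change calls this function with
+// R = T1 = T2 = float and func = ApplyPrelu<float> to broadcast the alpha
+// tensor against the input.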
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction(const T1* input1_data,
+ const Dims<4>& input1_dims,
+ const T2* input2_data,
+ const Dims<4>& input2_dims, R* output_data,
+ const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ func(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index e632728841..5ad0f4d232 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -110,6 +110,33 @@ TfLiteRegistration* Register_PACK();
TfLiteRegistration* Register_ONE_HOT();
TfLiteRegistration* Register_LOGICAL_OR();
+TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
+ context->ReportError(
+ context,
+ "Regular TensorFlow ops are not supported by this interpreter. Make sure "
+ "you invoke the Eager delegate before inference.");
+ return kTfLiteError;
+}
+
+const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op,
+ int version) const {
+ return MutableOpResolver::FindOp(op, version);
+}
+
+const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
+ int version) const {
+  // Return the NULL Op for all ops whose names start with "Eager:", allowing
+ // the interpreter to delegate their execution.
+ if (string(op).find("Eager:") == 0) {
+ static TfLiteRegistration null_op{
+ nullptr, nullptr, &UnsupportedTensorFlowOp,
+ nullptr, nullptr, BuiltinOperator_CUSTOM,
+ "Eager", 1};
+ return &null_op;
+ }
+ return MutableOpResolver::FindOp(op, version);
+}
+
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
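
A hedged usage sketch of the new FindOp override (the op name "Eager:MatMul" is made
up for illustration): any custom op whose name starts with "Eager:" resolves to the
placeholder registration above, whose handler reports that regular TensorFlow ops are
not supported and fails with kTfLiteError unless the Eager delegate takes over the node.

    #include "tensorflow/contrib/lite/kernels/register.h"

    void ResolveEagerOpSketch() {
      tflite::ops::builtin::BuiltinOpResolver resolver;
      // Hypothetical flex op name; anything prefixed with "Eager:" behaves the same.
      const TfLiteRegistration* reg = resolver.FindOp("Eager:MatMul", /*version=*/1);
      // `reg` is non-null, but running the node through it without the Eager
      // delegate installed produces the UnsupportedTensorFlowOp error above.
      (void)reg;
    }
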
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index 940718d67e..0296152d68 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -26,6 +26,10 @@ namespace builtin {
class BuiltinOpResolver : public MutableOpResolver {
public:
BuiltinOpResolver();
+
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
};
} // namespace builtin
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index 7be5e66c16..fec2a6f0d9 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -187,7 +187,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return ResizeOutputShape(context, output_shape, output);
}
-template <typename T, typename I>
+template <typename T, typename TI>
TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* output_shape =
@@ -204,10 +204,10 @@ TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
const int num_indices = SizeOfDimension(indices, 0);
const bool value_is_scalar = NumDimensions(values) == 0;
- std::vector<std::vector<I>> indices_vector;
+ std::vector<std::vector<TI>> indices_vector;
indices_vector.reserve(num_indices);
- TF_LITE_ENSURE_OK(context, GetIndicesVector<I>(context, indices, num_indices,
- &indices_vector));
+ TF_LITE_ENSURE_OK(context, GetIndicesVector<TI>(context, indices, num_indices,
+ &indices_vector));
reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values),
*GetTensorData<T>(default_value),
GetTensorData<T>(output), GetTensorDims(output),
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
index af77f07474..5181a8f89a 100644
--- a/tensorflow/contrib/lite/kernels/tile.cc
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -87,8 +87,9 @@ std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions,
if (dimension == in_dimensions.size - 1) {
CopyMultipleTimes(in_data, dimension_size, multipliers[dimension],
out_data);
- return std::make_pair(dimension_size,
- dimension_size * multipliers[dimension]);
+ return std::make_pair(
+ dimension_size,
+ dimension_size * static_cast<int>(multipliers[dimension]));
}
int total_stride_size = 0, total_tiled_stride_size = 0;
const T* copy_from_data = in_data;
diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc
new file mode 100644
index 0000000000..fa9a3cd1d8
--- /dev/null
+++ b/tensorflow/contrib/lite/mmap_allocation.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+
+namespace tflite {
+
+MMAPAllocation::MMAPAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) {
+ mmap_fd_ = open(filename, O_RDONLY);
+ if (mmap_fd_ == -1) {
+ error_reporter_->Report("Could not open '%s'.", filename);
+ return;
+ }
+ struct stat sb;
+ fstat(mmap_fd_, &sb);
+ buffer_size_bytes_ = sb.st_size;
+ mmapped_buffer_ =
+ mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0);
+ if (mmapped_buffer_ == MAP_FAILED) {
+ error_reporter_->Report("Mmap of '%s' failed.", filename);
+ return;
+ }
+}
+
+MMAPAllocation::~MMAPAllocation() {
+ if (valid()) {
+ munmap(const_cast<void*>(mmapped_buffer_), buffer_size_bytes_);
+ }
+ if (mmap_fd_ != -1) close(mmap_fd_);
+}
+
+const void* MMAPAllocation::base() const { return mmapped_buffer_; }
+
+size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; }
+
+bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; }
+
+bool MMAPAllocation::IsSupported() { return true; }
+
+} // namespace tflite
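
A minimal usage sketch for the new allocation class, assuming a standard TFLite build
(the file path is hypothetical; BuildFromBuffer is used with the signature visible in
model.cc further down):

    #include <memory>

    #include "tensorflow/contrib/lite/allocation.h"
    #include "tensorflow/contrib/lite/error_reporter.h"
    #include "tensorflow/contrib/lite/model.h"

    std::unique_ptr<tflite::FlatBufferModel> LoadMappedModelSketch() {
      tflite::ErrorReporter* reporter = tflite::DefaultErrorReporter();
      // Hypothetical path; the allocation reports an error and stays invalid on failure.
      static tflite::MMAPAllocation allocation("/tmp/model.tflite", reporter);
      if (!allocation.valid()) return nullptr;
      // The model borrows the mapped buffer, so the allocation must outlive it
      // (kept static here purely to make the sketch self-contained).
      return tflite::FlatBufferModel::BuildFromBuffer(
          static_cast<const char*>(allocation.base()), allocation.bytes(), reporter);
    }
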
diff --git a/tensorflow/contrib/lite/mmap_allocation_disabled.cc b/tensorflow/contrib/lite/mmap_allocation_disabled.cc
new file mode 100644
index 0000000000..f3d4cf1a25
--- /dev/null
+++ b/tensorflow/contrib/lite/mmap_allocation_disabled.cc
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/allocation.h"
+
+#include <cassert>
+
+namespace tflite {
+
+MMAPAllocation::MMAPAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : Allocation(error_reporter), mmapped_buffer_(nullptr) {
+ // The disabled variant should never be created.
+ assert(false);
+}
+
+MMAPAllocation::~MMAPAllocation() {}
+
+const void* MMAPAllocation::base() const { return nullptr; }
+
+size_t MMAPAllocation::bytes() const { return 0; }
+
+bool MMAPAllocation::valid() const { return false; }
+
+bool MMAPAllocation::IsSupported() { return false; }
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 5814cddc5b..9edf5ba38f 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
-#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
@@ -24,7 +23,9 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"
+#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#endif
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
@@ -73,6 +74,7 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
return kTfLiteOk;
}
+#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
@@ -80,8 +82,8 @@ std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
ErrorReporter* error_reporter,
bool use_nnapi) {
std::unique_ptr<Allocation> allocation;
- if (mmap_file) {
- if (use_nnapi && NNAPIExists())
+ if (mmap_file && MMAPAllocation::IsSupported()) {
+ if (use_nnapi && NNAPIDelegate::IsSupported())
allocation.reset(new NNAPIAllocation(filename, error_reporter));
else
allocation.reset(new MMAPAllocation(filename, error_reporter));
@@ -120,6 +122,7 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
if (!model->initialized()) model.reset();
return model;
}
+#endif
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
@@ -781,6 +784,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_TRANSPOSE:
case BuiltinOperator_POW:
case BuiltinOperator_LOGICAL_OR:
+ case BuiltinOperator_LOGICAL_AND:
+ case BuiltinOperator_LOGICAL_NOT:
break;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h
index 90260c8d62..3151192d92 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor.h
+++ b/tensorflow/contrib/lite/models/smartreply/predictor.h
@@ -65,9 +65,9 @@ struct SmartReplyConfig {
float backoff_confidence;
// Backoff responses are used when predicted responses cannot fulfill the
// list.
- const std::vector<std::string>& backoff_responses;
+ std::vector<std::string> backoff_responses;
- SmartReplyConfig(std::vector<std::string> backoff_responses)
+ SmartReplyConfig(const std::vector<std::string>& backoff_responses)
: num_response(kDefaultNumResponse),
backoff_confidence(kDefaultBackoffConfidence),
backoff_responses(backoff_responses) {}
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 1c06b29deb..c91f488175 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -624,6 +624,8 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_PACK:
case tflite::BuiltinOperator_LOGICAL_OR:
case tflite::BuiltinOperator_ONE_HOT:
+ case tflite::BuiltinOperator_LOGICAL_AND:
+ case tflite::BuiltinOperator_LOGICAL_NOT:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
@@ -789,4 +791,6 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
return kTfLiteOk;
}
+bool NNAPIDelegate::IsSupported() { return NNAPIExists(); }
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 8dc7d38a30..2bdb2cc5c8 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -19,9 +19,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
-class ANeuralNetworsModel;
+class ANeuralNetworksModel;
+class ANeuralNetworksMemory;
+class ANeuralNetworksCompilation;
namespace tflite {
@@ -54,6 +55,9 @@ class NNAPIDelegate {
// Run
TfLiteStatus Invoke(Interpreter* interpreter);
+ // Whether the current platform supports NNAPI delegation.
+ static bool IsSupported();
+
private:
// The NN API model handle
ANeuralNetworksModel* nn_model_ = nullptr;
diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
new file mode 100644
index 0000000000..efde72b1a7
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/nnapi_delegate.h"
+
+#include <cassert>
+
+namespace tflite {
+
+NNAPIAllocation::NNAPIAllocation(const char* filename,
+ ErrorReporter* error_reporter)
+ : MMAPAllocation(filename, error_reporter) {
+ // The disabled variant should never be created.
+ assert(false);
+}
+
+NNAPIAllocation::~NNAPIAllocation() {}
+
+NNAPIDelegate::~NNAPIDelegate() {}
+
+TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
+ return kTfLiteError;
+}
+
+TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
+ return kTfLiteError;
+}
+
+bool NNAPIDelegate::IsSupported() { return false; }
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 8ed98ddaf4..14f88b4c00 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -167,6 +167,8 @@ enum BuiltinOperator : byte {
PACK = 83,
LOGICAL_OR = 84,
ONE_HOT = 85,
+ LOGICAL_AND = 86,
+ LOGICAL_NOT = 87,
}
// Options for the builtin operators.
@@ -232,6 +234,8 @@ union BuiltinOptions {
PackOptions,
LogicalOrOptions,
OneHotOptions,
+ LogicalAndOptions,
+ LogicalNotOptions,
}
enum Padding : byte { SAME, VALID }
@@ -555,6 +559,12 @@ table OneHotOptions {
axis:int;
}
+table LogicalAndOptions {
+}
+
+table LogicalNotOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 4402f89b85..3efa153e2c 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -214,6 +214,12 @@ struct LogicalOrOptionsT;
struct OneHotOptions;
struct OneHotOptionsT;
+struct LogicalAndOptions;
+struct LogicalAndOptionsT;
+
+struct LogicalNotOptions;
+struct LogicalNotOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -365,11 +371,13 @@ enum BuiltinOperator {
BuiltinOperator_PACK = 83,
BuiltinOperator_LOGICAL_OR = 84,
BuiltinOperator_ONE_HOT = 85,
+ BuiltinOperator_LOGICAL_AND = 86,
+ BuiltinOperator_LOGICAL_NOT = 87,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_ONE_HOT
+ BuiltinOperator_MAX = BuiltinOperator_LOGICAL_NOT
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -455,7 +463,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] {
BuiltinOperator_REDUCE_MAX,
BuiltinOperator_PACK,
BuiltinOperator_LOGICAL_OR,
- BuiltinOperator_ONE_HOT
+ BuiltinOperator_ONE_HOT,
+ BuiltinOperator_LOGICAL_AND,
+ BuiltinOperator_LOGICAL_NOT
};
return values;
}
@@ -548,6 +558,8 @@ inline const char **EnumNamesBuiltinOperator() {
"PACK",
"LOGICAL_OR",
"ONE_HOT",
+ "LOGICAL_AND",
+ "LOGICAL_NOT",
nullptr
};
return names;
@@ -621,11 +633,13 @@ enum BuiltinOptions {
BuiltinOptions_PackOptions = 59,
BuiltinOptions_LogicalOrOptions = 60,
BuiltinOptions_OneHotOptions = 61,
+ BuiltinOptions_LogicalAndOptions = 62,
+ BuiltinOptions_LogicalNotOptions = 63,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_OneHotOptions
+ BuiltinOptions_MAX = BuiltinOptions_LogicalNotOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -688,7 +702,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] {
BuiltinOptions_FakeQuantOptions,
BuiltinOptions_PackOptions,
BuiltinOptions_LogicalOrOptions,
- BuiltinOptions_OneHotOptions
+ BuiltinOptions_OneHotOptions,
+ BuiltinOptions_LogicalAndOptions,
+ BuiltinOptions_LogicalNotOptions
};
return values;
}
@@ -757,6 +773,8 @@ inline const char **EnumNamesBuiltinOptions() {
"PackOptions",
"LogicalOrOptions",
"OneHotOptions",
+ "LogicalAndOptions",
+ "LogicalNotOptions",
nullptr
};
return names;
@@ -1015,6 +1033,14 @@ template<> struct BuiltinOptionsTraits<OneHotOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions;
};
+template<> struct BuiltinOptionsTraits<LogicalAndOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions;
+};
+
+template<> struct BuiltinOptionsTraits<LogicalNotOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1534,6 +1560,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_OneHotOptions ?
reinterpret_cast<const OneHotOptionsT *>(value) : nullptr;
}
+ LogicalAndOptionsT *AsLogicalAndOptions() {
+ return type == BuiltinOptions_LogicalAndOptions ?
+ reinterpret_cast<LogicalAndOptionsT *>(value) : nullptr;
+ }
+ const LogicalAndOptionsT *AsLogicalAndOptions() const {
+ return type == BuiltinOptions_LogicalAndOptions ?
+ reinterpret_cast<const LogicalAndOptionsT *>(value) : nullptr;
+ }
+ LogicalNotOptionsT *AsLogicalNotOptions() {
+ return type == BuiltinOptions_LogicalNotOptions ?
+ reinterpret_cast<LogicalNotOptionsT *>(value) : nullptr;
+ }
+ const LogicalNotOptionsT *AsLogicalNotOptions() const {
+ return type == BuiltinOptions_LogicalNotOptions ?
+ reinterpret_cast<const LogicalNotOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5527,6 +5569,86 @@ inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(
flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct LogicalAndOptionsT : public flatbuffers::NativeTable {
+ typedef LogicalAndOptions TableType;
+ LogicalAndOptionsT() {
+ }
+};
+
+struct LogicalAndOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LogicalAndOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ LogicalAndOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<LogicalAndOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct LogicalAndOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LogicalAndOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LogicalAndOptionsBuilder &operator=(const LogicalAndOptionsBuilder &);
+ flatbuffers::Offset<LogicalAndOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LogicalAndOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LogicalAndOptions> CreateLogicalAndOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LogicalAndOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<LogicalAndOptions> CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct LogicalNotOptionsT : public flatbuffers::NativeTable {
+ typedef LogicalNotOptions TableType;
+ LogicalNotOptionsT() {
+ }
+};
+
+struct LogicalNotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LogicalNotOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ LogicalNotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<LogicalNotOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct LogicalNotOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit LogicalNotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LogicalNotOptionsBuilder &operator=(const LogicalNotOptionsBuilder &);
+ flatbuffers::Offset<LogicalNotOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<LogicalNotOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ LogicalNotOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5843,6 +5965,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const OneHotOptions *builtin_options_as_OneHotOptions() const {
return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast<const OneHotOptions *>(builtin_options()) : nullptr;
}
+ const LogicalAndOptions *builtin_options_as_LogicalAndOptions() const {
+ return builtin_options_type() == BuiltinOptions_LogicalAndOptions ? static_cast<const LogicalAndOptions *>(builtin_options()) : nullptr;
+ }
+ const LogicalNotOptions *builtin_options_as_LogicalNotOptions() const {
+ return builtin_options_type() == BuiltinOptions_LogicalNotOptions ? static_cast<const LogicalNotOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6118,6 +6246,14 @@ template<> inline const OneHotOptions *Operator::builtin_options_as<OneHotOption
return builtin_options_as_OneHotOptions();
}
+template<> inline const LogicalAndOptions *Operator::builtin_options_as<LogicalAndOptions>() const {
+ return builtin_options_as_LogicalAndOptions();
+}
+
+template<> inline const LogicalNotOptions *Operator::builtin_options_as<LogicalNotOptions>() const {
+ return builtin_options_as_LogicalNotOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -8259,6 +8395,52 @@ inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatB
_axis);
}
+inline LogicalAndOptionsT *LogicalAndOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new LogicalAndOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void LogicalAndOptions::UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<LogicalAndOptions> LogicalAndOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateLogicalAndOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<LogicalAndOptions> CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalAndOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateLogicalAndOptions(
+ _fbb);
+}
+
+inline LogicalNotOptionsT *LogicalNotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new LogicalNotOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void LogicalNotOptions::UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<LogicalNotOptions> LogicalNotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateLogicalNotOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<LogicalNotOptions> CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalNotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateLogicalNotOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8692,6 +8874,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_LogicalAndOptions: {
+ auto ptr = reinterpret_cast<const LogicalAndOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LogicalNotOptions: {
+ auto ptr = reinterpret_cast<const LogicalNotOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8954,6 +9144,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_LogicalAndOptions: {
+ auto ptr = reinterpret_cast<const LogicalAndOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_LogicalNotOptions: {
+ auto ptr = reinterpret_cast<const LogicalNotOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9204,6 +9402,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const OneHotOptionsT *>(value);
return CreateOneHotOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_LogicalAndOptions: {
+ auto ptr = reinterpret_cast<const LogicalAndOptionsT *>(value);
+ return CreateLogicalAndOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_LogicalNotOptions: {
+ auto ptr = reinterpret_cast<const LogicalNotOptionsT *>(value);
+ return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9454,6 +9660,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new OneHotOptionsT(*reinterpret_cast<OneHotOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_LogicalAndOptions: {
+ value = new LogicalAndOptionsT(*reinterpret_cast<LogicalAndOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_LogicalNotOptions: {
+ value = new LogicalNotOptionsT(*reinterpret_cast<LogicalNotOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9766,6 +9980,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_LogicalAndOptions: {
+ auto ptr = reinterpret_cast<LogicalAndOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_LogicalNotOptions: {
+ auto ptr = reinterpret_cast<LogicalNotOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 6d03c0fd9e..1bbf918fd7 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -687,12 +687,20 @@ def make_relu6_tests(zip_path):
def make_prelu_tests(zip_path):
"""Make a set of tests to do PReLU."""
- test_parameters = [{
- # The canonical case for image processing is having a 4D `input` (NHWC)
- # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
- "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
- "shared_axes": [[1, 2], [1]],
- }]
+ test_parameters = [
+ {
+ # The canonical case for image processing is having a 4D `input`
+            # (NHWC) and `shared_axes`=[1, 2], so the alpha parameter is per
+ # channel.
+ "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+ "shared_axes": [[1, 2], [1]],
+ },
+ {
+ # 2D-3D example. Share the 2nd axis.
+ "input_shape": [[20, 20], [20, 20, 20]],
+ "shared_axes": [[1]],
+ }
+ ]
def build_graph(parameters):
"""Build the graph for the test case."""
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 106cbc1b8e..e475f256c0 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -86,9 +86,6 @@ std::map<string, string> kBrokenTests = {
// Transpose only supports 1D-4D input tensors.
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
- // PRelu only supports 4D input with (1, 1, channels) 3D alpha now.
- {R"(^\/prelu.*shared_axes=\[1\])", "75975192"},
-
// No support for axis!=0 in GatherV2.
{R"(^\/gather.*axis=1)", "76910444"},
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 378212cb74..8b41865985 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1940,6 +1940,21 @@ void ConvertLogicalOrOperator(const Model& model,
(*logical_or_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertCTCBeamSearchDecoderOperator(
+ const Model& model, const CTCBeamSearchDecoderOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ auto* op = tensorflow_graph->add_node();
+ op->set_op(op_name);
+ op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *op->add_input() = src_op.inputs[i];
+ }
+ (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
+ (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
+ (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2194,6 +2209,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertLogicalOrOperator(model,
static_cast<const LogicalOrOperator&>(src_op),
"LogicalOr", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
+ ConvertCTCBeamSearchDecoderOperator(
+ model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
+ "CTCBeamSearchDecoder", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 2f1bb8f0ad..527013bfa3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -377,6 +377,19 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
case OperatorType::kMean:
changed = HardcodeMinMaxFromFirstInput(model, op);
break;
+ case OperatorType::kSum:
+      // reduce_sum is expected to change the output range, so a fake_quant
+      // op on the output is normally needed to minimize error. However, in
+      // special circumstances, such as computing an expected value with
+      // reduce_sum, the input and output ranges match, so the code below
+      // acts as a fallback. If a fake_quant node is observed on the output,
+      // it takes precedence over this hardcoding logic.
+ changed = HardcodeMinMaxFromFirstInput(model, op);
+ if (changed) {
+ LOG(WARNING) << "Using the input range for output in reduce_sum op."
+ << "This could have an impact on your model accuracy.";
+ }
+ break;
case OperatorType::kSelect:
changed = HardcodeMinMaxForSelect(model, op);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index f033ee013e..c8310161cb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -215,6 +215,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[0]).data_type = on_value_type;
break;
}
+ case OperatorType::kCTCBeamSearchDecoder: {
+ CHECK_EQ(op->inputs.size(), 2);
+      // All outputs (sparse tensors) are int32s (although tf uses int64s),
+      // except the last one (log probabilities), which is float.
+ const int output_size = op->outputs.size();
+ for (int i = 0; i < output_size - 1; ++i) {
+ model->GetArray(op->outputs[i]).data_type = ArrayDataType::kInt32;
+ }
+ model->GetArray(op->outputs[output_size - 1]).data_type =
+ ArrayDataType::kFloat;
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index f6ce3b3ecb..b5a6554c1d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -50,7 +50,7 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
type == OperatorType::kTanh || type == OperatorType::kMul ||
- type == OperatorType::kBatchToSpaceND ||
+ type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum ||
type == OperatorType::kSpaceToBatchND ||
type == OperatorType::kSpaceToDepth ||
type == OperatorType::kStridedSlice ||
@@ -61,9 +61,20 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kGreaterEqual || type == OperatorType::kLess ||
type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
type == OperatorType::kArgMax || type == OperatorType::kRelu ||
- type == OperatorType::kRelu1 || type == OperatorType::kRelu6;
+ type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
+ type == OperatorType::kShape;
}
+// Returns true if the quantized op allows output arrays of type float, as
+// indicated by the attribute support_output_type_float_in_quantized_op.
+bool SupportOutputTypeFloatInQuantizedOp(const Operator& op) {
+ auto type = op.type;
+ if (type == OperatorType::kUnsupported) {
+ auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
+ return unsupported->support_output_type_float_in_quantized_op;
+ }
+ return false;
+}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
auto& array = model->GetArray(array_name);
// Normally we should have a MinMax recorded on this Array,
@@ -584,61 +595,67 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
}
// Quantize outputs, add Dequantize ops as needed on the outputs side
- for (std::size_t output_index = 0; output_index < op.outputs.size();
- output_index++) {
- ArrayDataType quantized_data_type;
- QuantizationParams quantization_params;
- if (ChooseQuantizationForOperatorOutput(this, model, op, output_index,
- &quantized_data_type,
- &quantization_params)) {
- changed = true;
- const auto& output = op.outputs[output_index];
- auto& output_array = model->GetArray(output);
-
- // Fix up the min/max information on the output array to match the chosen
- // quantization parameters.
- CHECK(output_array.minmax)
- << "Output array named " << output << " lacks minmax";
- auto& output_minmax = output_array.GetMinMax();
- FixMinMaxPostQuantization(this, quantized_data_type, quantization_params,
- &output_minmax);
-
- QuantizeArray(this, model, output, quantized_data_type,
- quantization_params);
-
- const auto& dequantized_output =
- AvailableArrayName(*model, output + "_dequantized");
- auto& dequantized_output_array =
- model->GetOrCreateArray(dequantized_output);
- dequantized_output_array.data_type = ArrayDataType::kFloat;
- dequantized_output_array.final_data_type = output_array.data_type;
- auto& dequantized_output_minmax =
- dequantized_output_array.GetOrCreateMinMax();
- dequantized_output_minmax.min = output_minmax.min;
- dequantized_output_minmax.max = output_minmax.max;
- for (const auto& other_op : model->operators) {
- for (auto& other_op_input : other_op->inputs) {
- if (other_op_input == output) {
- other_op_input = dequantized_output;
+ if (SupportOutputTypeFloatInQuantizedOp(op)) {
+ LOG(WARNING)
+ << HelpfulOperatorTypeName(op) << " is a quantized op"
+ << "but it has a model flag that sets the output arrays to float.";
+ } else {
+ for (std::size_t output_index = 0; output_index < op.outputs.size();
+ output_index++) {
+ QuantizationParams quantization_params;
+ ArrayDataType quantized_data_type;
+ if (ChooseQuantizationForOperatorOutput(this, model, op, output_index,
+ &quantized_data_type,
+ &quantization_params)) {
+ changed = true;
+ const auto& output = op.outputs[output_index];
+ auto& output_array = model->GetArray(output);
+
+ // Fix up the min/max information on the output array to match the
+ // chosen quantization parameters.
+ CHECK(output_array.minmax)
+ << "Output array named " << output << " lacks minmax";
+ auto& output_minmax = output_array.GetMinMax();
+ FixMinMaxPostQuantization(this, quantized_data_type,
+ quantization_params, &output_minmax);
+
+ QuantizeArray(this, model, output, quantized_data_type,
+ quantization_params);
+
+ const auto& dequantized_output =
+ AvailableArrayName(*model, output + "_dequantized");
+ auto& dequantized_output_array =
+ model->GetOrCreateArray(dequantized_output);
+ dequantized_output_array.data_type = ArrayDataType::kFloat;
+ dequantized_output_array.final_data_type = output_array.data_type;
+ auto& dequantized_output_minmax =
+ dequantized_output_array.GetOrCreateMinMax();
+ dequantized_output_minmax.min = output_minmax.min;
+ dequantized_output_minmax.max = output_minmax.max;
+ for (const auto& other_op : model->operators) {
+ for (auto& other_op_input : other_op->inputs) {
+ if (other_op_input == output) {
+ other_op_input = dequantized_output;
+ }
}
}
- }
- auto* dequantize_op = new DequantizeOperator;
- dequantize_op->inputs = {output};
- dequantize_op->outputs = {dequantized_output};
- for (int i = 0; i < model->flags.output_arrays_size(); i++) {
- if (model->flags.output_arrays(i) == output) {
- // TODO(b/78013785): never rename output arrays.
- AddMessageF(
- "Renaming output array %d after inserting dequant op %s: %s -> "
- "%s",
- i, LogName(*dequantize_op), model->flags.output_arrays(i),
- dequantized_output);
- model->flags.set_output_arrays(i, dequantized_output);
+ auto* dequantize_op = new DequantizeOperator;
+ dequantize_op->inputs = {output};
+ dequantize_op->outputs = {dequantized_output};
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) == output) {
+ // TODO(b/78013785): never rename output arrays.
+ AddMessageF(
+ "Renaming output array %d after inserting dequant op %s: %s -> "
+ "%s",
+ i, LogName(*dequantize_op), model->flags.output_arrays(i),
+ dequantized_output);
+ model->flags.set_output_arrays(i, dequantized_output);
+ }
}
+ const auto op_it = FindOp(*model, &op);
+ model->operators.emplace(op_it + 1, dequantize_op);
}
- const auto op_it = FindOp(*model, &op);
- model->operators.emplace(op_it + 1, dequantize_op);
}
}
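
As a hedged illustration of the producer side, a graph that wants this behavior would
set the attributes parsed in import_tensorflow.cc (next file); the op and node names
below are hypothetical, and only the attribute keys come from the code.

    #include "tensorflow/core/framework/graph.pb.h"

    void TagCustomOpSketch(tensorflow::GraphDef* graph_def) {
      tensorflow::NodeDef* node = graph_def->add_node();
      node->set_op("MyPostProcess");   // hypothetical custom op
      node->set_name("postprocess");   // hypothetical node name
      // Treat the op as quantized, but allow its output arrays to stay float.
      (*node->mutable_attr())["_output_quantized"].set_b(true);
      (*node->mutable_attr())["_support_output_type_float_in_quantized_op"].set_b(true);
    }
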
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9a3db5c888..d8d331f3d4 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1049,6 +1049,8 @@ tensorflow::Status ConvertUnsupportedOperator(
static constexpr char kAttrOutputQuantized[] = "_output_quantized";
static constexpr char kAttrOutputTypes[] = "_output_types";
static constexpr char kAttrOutputShapes[] = "_output_shapes";
+ static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
+ "_support_output_type_float_in_quantized_op";
LOG(INFO) << "Converting unsupported operation: " << node.op();
auto* op = new TensorFlowUnsupportedOperator;
@@ -1060,9 +1062,15 @@ tensorflow::Status ConvertUnsupportedOperator(
op->tensorflow_op = node.op();
node.SerializeToString(&op->tensorflow_node_def);
model->operators.emplace_back(op);
+  // Parse whether the op supports quantization.
if (HasAttr(node, kAttrOutputQuantized)) {
op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
}
+  // Parse whether the quantized op allows output arrays of type float.
+ if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
+ op->support_output_type_float_in_quantized_op =
+ GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
+ }
if (HasAttr(node, kAttrOutputTypes)) {
const auto& output_types = GetListAttr(node, kAttrOutputTypes);
for (int i = 0; i < output_types.type_size(); ++i) {
@@ -1854,6 +1862,34 @@ tensorflow::Status ConvertOneHotOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
+
+ auto* op = new CTCBeamSearchDecoderOperator;
+ for (const string& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+
+ op->beam_width =
+ HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
+ op->top_paths =
+ HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
+ op->merge_repeated = HasAttr(node, "merge_repeated")
+ ? GetBoolAttr(node, "merge_repeated")
+ : true;
+
+ // There are top_paths + 1 outputs.
+ op->outputs.push_back(node.name()); // Implicit :0.
+ for (int i = 0; i < op->top_paths; ++i) {
+ op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
+ }
+ model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -1888,6 +1924,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Const", ConvertConstOperator},
{"Conv2D", ConvertConvOperator},
{"Conv2DBackpropInput", ConvertTransposeConvOperator},
+ {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
{"DepthToSpace", ConvertDepthToSpaceOperator},
{"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
{"Div", ConvertSimpleOperator<DivOperator, 2>},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 7d0dbfcc05..18c78e32d0 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -148,6 +148,7 @@ enum class OperatorType : uint8 {
kLogicalAnd,
kLogicalNot,
kLogicalOr,
+ kCTCBeamSearchDecoder,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -438,6 +439,28 @@ struct ConvOperator : Operator {
int dilation_height_factor = 1;
};
+// CTCBeamSearchDecoder operator:
+//
+// Inputs:
+// inputs[0]: required: the logits.
+// inputs[1]: required: sequence length.
+// inputs[2]: optional: beam width.
+// inputs[3]: optional: top paths.
+// inputs[4]: optional: merge repeated.
+//
+// Outputs:
+//  outputs[0]: decoded.
+// outputs[1]: log probability.
+//
+// TensorFlow equivalent: CTCBeamSearchDecoder
+struct CTCBeamSearchDecoderOperator : Operator {
+ CTCBeamSearchDecoderOperator()
+ : Operator(OperatorType::kCTCBeamSearchDecoder) {}
+ int beam_width;
+ int top_paths;
+ bool merge_repeated = true;
+};
+
// Depthwise-separable convolution operator.
//
// Inputs:
@@ -1509,6 +1532,9 @@ struct TensorFlowUnsupportedOperator : Operator {
string tensorflow_node_def;
// A boolean indicating if the unsupported op should be treated as quantized.
bool quantized = false;
+ // A boolean indicating if the unsupported op output should allow float values
+ // in quantized mode.
+ bool support_output_type_float_in_quantized_op = false;
// Output data types
std::vector<ArrayDataType> output_data_types;
// Output shapes.
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 9380168f30..4ece561e97 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1070,6 +1070,27 @@ class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+class CTCBeamSearchDecoder
+ : public CustomOperator<CTCBeamSearchDecoderOperator> {
+ public:
+ using CustomOperator::CustomOperator;
+
+ void WriteOptions(const TocoOperator& op,
+ flexbuffers::Builder* fbb) const override {
+ fbb->Int("beam_width", op.beam_width);
+ fbb->Int("top_paths", op.top_paths);
+ fbb->Bool("merge_repeated", op.merge_repeated);
+ }
+
+ void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ op->beam_width = m["beam_width"].AsInt32();
+ op->top_paths = m["top_paths"].AsInt32();
+ op->merge_repeated = m["merge_repeated"].AsBool();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -1179,6 +1200,12 @@ class TensorFlowUnsupported : public BaseOperator {
break;
case flexbuffers::TYPE_BOOL:
(*attr)[key].set_b(value.AsBool());
+ if (string(key) == "_output_quantized") {
+ op->quantized = value.AsBool();
+ }
+ if (string(key) == "_support_output_type_float_in_quantized_op") {
+ op->support_output_type_float_in_quantized_op = value.AsBool();
+ }
break;
case flexbuffers::TYPE_VECTOR_INT: {
auto* list = (*attr)[key].mutable_list();
@@ -1301,6 +1328,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
// Custom Operators.
ops.emplace_back(
new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
+ ops.emplace_back(new CTCBeamSearchDecoder(
+ "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED",
OperatorType::kUnsupported));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 384f7c118d..12fdbbf214 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -472,6 +472,20 @@ TEST_F(OperatorTest, BuiltinOneHot) {
EXPECT_EQ(op.axis, output_toco_op->axis);
}
+TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
+ CTCBeamSearchDecoderOperator op;
+ op.beam_width = 3;
+ op.top_paths = 2;
+ op.merge_repeated = false;
+ std::unique_ptr<toco::CTCBeamSearchDecoderOperator> output_toco_op =
+ SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER",
+ OperatorType::kCTCBeamSearchDecoder),
+ op);
+ EXPECT_EQ(op.beam_width, output_toco_op->beam_width);
+ EXPECT_EQ(op.top_paths, output_toco_op->top_paths);
+ EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc
index de76fd4032..14168fa33f 100644
--- a/tensorflow/contrib/lite/toco/toco_port.cc
+++ b/tensorflow/contrib/lite/toco/toco_port.cc
@@ -38,7 +38,8 @@ void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); }
} // namespace port
} // namespace toco
-#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__)
+#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && \
+ !defined(__ANDROID__) && !defined(_WIN32)
// Wrap Google file operations.
@@ -115,9 +116,12 @@ string JoinPath(const string& a, const string& b) {
} // namespace port
} // namespace toco
-#else // (__APPLE__ || __ANDROID__)
+#else // !PLATFORM_GOOGLE || __APPLE__ || __ANDROID__ || _WIN32
#include <fcntl.h>
+#if defined(_WIN32)
+#include <io.h> // for _close, _open, _read
+#endif
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
@@ -130,6 +134,19 @@ string JoinPath(const string& a, const string& b) {
namespace toco {
namespace port {
+#if defined(_WIN32)
+#define close _close
+#define open _open
+#define read _read
+#define O_RDONLY _O_RDONLY
+#define O_CREAT _O_CREAT
+#define O_WRONLY _O_WRONLY
+// Windows does not support the same set of file permissions as other platforms.
+constexpr int kFileCreateMode = _S_IREAD | _S_IWRITE;
+#else
+constexpr int kFileCreateMode = 0664;
+#endif // _WIN32
+
static bool port_initialized = false;
void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) {
@@ -209,7 +226,7 @@ tensorflow::Status GetContents(const string& path, string* output,
tensorflow::Status SetContents(const string& filename, const string& contents,
const file::Options& options) {
- int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664);
+ int fd = open(filename.c_str(), O_WRONLY | O_CREAT, kFileCreateMode);
if (fd == -1) {
return tensorflow::errors::Internal("can't open() for write");
}
@@ -243,4 +260,4 @@ string JoinPath(const string& base, const string& filename) {
} // namespace port
} // namespace toco
-#endif // (__APPLE || __ANDROID__)
+#endif // !PLATFORM_GOOGLE || __APPLE || __ANDROID__ || _WIN32
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 68155c7329..80df09eb08 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -404,6 +404,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
+ HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh
index 48953e2e38..448ae6d22e 100755
--- a/tensorflow/contrib/makefile/download_dependencies.sh
+++ b/tensorflow/contrib/makefile/download_dependencies.sh
@@ -30,7 +30,11 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE
GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
-PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
+# Note: The Protobuf source referenced in `tensorflow/workspace.bzl` on the
+# TensorFlow 1.10 branch does not work: `make distclean` fails and blocks the
+# build process. For now we hardcode the version used by TensorFlow 1.9.
+PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz"
RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)"
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 06ab58188a..28a531dfec 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import tracking
@@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
- root.restore(core_saver.latest_checkpoint(checkpoint_directory))
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
for _ in range(num_training_steps):
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
input_value = constant_op.constant([[3.]])
@@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
@@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
@@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase):
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
@function.defun
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
index af3b2ad1b5..c2166594e5 100644
--- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
@@ -22,8 +22,8 @@ from __future__ import print_function
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.predictor import predictor
from tensorflow.python.framework import ops
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
class ContribEstimatorPredictor(predictor.Predictor):
@@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor):
# pylint: disable=protected-access
model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
# pylint: enable=protected-access
- checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ estimator.model_dir)
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
config=config,
diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py
index f275bc15ad..7886744b3c 100644
--- a/tensorflow/contrib/predictor/predictor_factories.py
+++ b/tensorflow/contrib/predictor/predictor_factories.py
@@ -108,6 +108,8 @@ def from_estimator(estimator,
def from_saved_model(export_dir,
signature_def_key=None,
signature_def=None,
+ input_names=None,
+ output_names=None,
tags=None,
graph=None,
config=None):
@@ -121,6 +123,12 @@ def from_saved_model(export_dir,
signature_def: A `SignatureDef` proto specifying the inputs and outputs
for prediction. Only one of `signature_def_key` and `signature_def`
should be specified.
+ input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel`
+ that represent the input. The keys can be any string of the user's
+ choosing.
+ output_names: A dictionary mapping strings to `Tensor`s in the
+ `SavedModel` that represent the output. The keys can be any string of
+ the user's choosing.
tags: Optional. Tags that will be used to retrieve the correct
`SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
@@ -138,6 +146,8 @@ def from_saved_model(export_dir,
export_dir,
signature_def_key=signature_def_key,
signature_def=signature_def,
+ input_names=input_names,
+ output_names=output_names,
tags=tags,
graph=graph,
config=config)
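To illustrate the new arguments above (not part of the patch): `from_saved_model` now forwards optional `input_names`/`output_names` dictionaries to the underlying SavedModel predictor. A hedged usage sketch, with a made-up export directory, keys, and tensor names:

```python
# Hedged sketch: naming the feeds/fetches of a SavedModel predictor explicitly.
# The export directory, dictionary keys and tensor names are illustrative only.
from tensorflow.contrib.predictor import predictor_factories

pred = predictor_factories.from_saved_model(
    export_dir="/tmp/exported_model",           # hypothetical export
    input_names={"images": "input_tensor:0"},   # user-chosen key -> input tensor
    output_names={"scores": "softmax:0"})       # user-chosen key -> output tensor
predictions = pred({"images": [[0.0] * 784]})
print(predictions["scores"])
```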
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 4fc315d901..903faeff11 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -261,6 +261,16 @@ def _FindLayersToQuantize(graph):
layer_output_pattern = graph_matcher.OneofPattern(
[batch_to_space_pattern, layer_pattern])
+
+ # For separable convolutions, we are looking for a conv followed by a conv,
+ # with no activations between the two.
+ sep_conv_pattern = graph_matcher.OpTypePattern(
+ '|'.join(_QUANTIZABLE_TYPES),
+ inputs=[
+ graph_matcher.OneofPattern([layer_output_pattern]),
+ graph_matcher.OpTypePattern('*')
+ ],
+ ordered_inputs=False)
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
'Mul',
inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
@@ -393,6 +403,17 @@ def _FindLayersToQuantize(graph):
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
+ # Look for separable convolutions here
+ sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
+ for match_result in sep_conv_matcher.match_graph(graph):
+ layer_op = match_result.get_op(layer_pattern)
+ weight_tensor = match_result.get_tensor(weight_identity_pattern)
+ activation_op = match_result.get_op(layer_pattern)
+ if layer_op not in matched_layer_set:
+ matched_layer_set.add(layer_op)
+ layer_matches.append(
+ _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
+
return layer_matches
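As an illustration of the separable-convolution pattern above (not part of the patch): the rewriter can now place fake-quant nodes on both the depthwise and pointwise weights and on the intermediate activation. A hedged sketch of exercising this through the public contrib API — layer sizes and scope names are arbitrary:

```python
# Hedged sketch: quantizing a graph that contains a separable convolution.
# All shapes, layer parameters and names are illustrative.
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib.layers import separable_conv2d

graph = tf.Graph()
with graph.as_default():
  images = tf.zeros((8, 64, 64, 3))
  net = separable_conv2d(
      images, num_outputs=16, kernel_size=[3, 3], depth_multiplier=1,
      activation_fn=tf.nn.relu6, scope='sep_conv')
  contrib_quantize.create_training_graph(input_graph=graph)

# List the inserted FakeQuant ops (weights of both convs plus activations).
for op in graph.get_operations():
  if op.type == 'FakeQuantWithMinMaxVars':
    print(op.name)
```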
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 92ca4a1b0c..98209fffb9 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -122,12 +122,67 @@ class QuantizeTest(test_util.TensorFlowTestCase):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+ # Check if output of bias add is quantized
+ quantization_node_name = 'FakeQuantWithMinMaxVars'
+ conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
+ quantization_node_name)
+ self.assertEqual(conv_quant.type, quantization_node_name)
+
+ for op in graph.get_operations():
+ if op.type == quantization_node_name:
+ quant_op = graph.get_operation_by_name(op.name)
+ # Scan through all FakeQuant operations, ensuring that the activation
+ # identity op isn't in the consumers of the operation.
+ consumers = []
+ for output in quant_op.outputs:
+ consumers.extend(output.consumers())
+ self.assertNotIn('test/relu6', [c.name for c in consumers])
+
+ def testInsertQuantOpInSeparableConv2d(self):
+ self._RunTestOverParameters(self._TestInsertQuantOpInSeparableConv2d)
+
+ def _TestInsertQuantOpInSeparableConv2d(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth))
+ conv = separable_conv2d(
+ input1,
+ 3, [5, 5],
+ stride=2,
+ depth_multiplier=1.0,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test/test')
+ node = math_ops.add(conv, input2, name='test/add')
+ node = nn_ops.relu6(node, name='test/relu6')
+ update_barrier = control_flow_ops.no_op(name='update_barrier')
+ with ops.control_dependencies([update_barrier]):
+ array_ops.identity(node, name='control_dependency')
+
+ quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
+ # Check if output of bias add is quantized
quantization_node_name = 'FakeQuantWithMinMaxVars'
conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
+ # Check if weights for both convs inside separable conv are quantized
+ pointwise_weight_quant = graph.get_operation_by_name(
+ 'test/test/weights_quant/' + quantization_node_name)
+ self.assertEqual(pointwise_weight_quant.type, quantization_node_name)
+ depthwise_weight_quant = graph.get_operation_by_name(
+ 'test/test/separable_conv2d/weights_quant/' + quantization_node_name)
+ self.assertEqual(depthwise_weight_quant.type, quantization_node_name)
+
+ # Check if activations after first depthwise conv are quantized.
+ depthwise_act_quant = graph.get_operation_by_name(
+ 'test/test/separable_conv2d/act_quant/' + quantization_node_name)
+ self.assertEqual(depthwise_act_quant.type, quantization_node_name)
+
for op in graph.get_operations():
if op.type == quantization_node_name:
quant_op = graph.get_operation_by_name(op.name)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 60e8368c37..fc0d22d112 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -85,11 +85,12 @@ cc_library(
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
+ ":test_utils",
":trt_allocator",
+ ":trt_conversion",
":trt_logging",
":trt_plugins",
":trt_resources",
- ":trt_conversion",
":utils",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
@@ -184,6 +185,8 @@ py_library(
],
)
+# TODO(aaroey): this wrapper has been causing double-linking trouble, so
+# either get rid of it, or split it so it contains only the minimum dependencies.
tf_py_wrap_cc(
name = "wrap_conversion",
srcs = ["trt_conversion.i"],
@@ -192,6 +195,7 @@ tf_py_wrap_cc(
"//tensorflow/python:platform/base.i",
],
deps = [
+ ":test_utils",
":trt_conversion",
":trt_engine_op_kernel",
"//third_party/python_runtime:headers",
@@ -264,6 +268,7 @@ tf_cuda_library(
],
deps = [
":segment",
+ ":test_utils",
":trt_allocator",
":trt_plugins",
":trt_logging",
@@ -274,7 +279,6 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core:gpu_runtime",
"//tensorflow/core:framework_lite",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
@@ -412,4 +416,17 @@ cc_library(
srcs = ["convert/utils.cc"],
hdrs = ["convert/utils.h"],
copts = tf_copts(),
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "test_utils",
+ srcs = ["test/utils.cc"],
+ hdrs = ["test/utils.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ "@com_googlesource_code_re2//:re2",
+ ],
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 3383f6bc9b..21ec8b0b30 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <map>
#include <set>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -29,9 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
@@ -195,20 +194,44 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
return tensorflow::Status::OK();
}
-// Entry function from Python.
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
int precision_mode, int minimum_segment_size, bool is_dyn_op,
int max_cached_engines, std::vector<int> cached_engine_batches) {
- // optimization pass
+ // Create GrapplerItem.
tensorflow::grappler::GrapplerItem item;
item.fetch = output_names;
item.graph = graph_def;
- // grappler requires a virtual cluster with a proper GPU device
- // in order to calculate flops>0 or fails with FATAL
- // We add numbers from a Pascal card here to have flops>0
+
+ // TODO(aaroey): we should have used a single machine cluster like the
+ // following, but the problem is that wrap_conversion would then depend on
+ // direct_session and cause double-linking problems. To fix this we need to
+ // fix or get rid of the swig dependency. Here we use VirtualCluster
+ // as a workaround, and we need to create a session to initialize the
+ // underlying device before calling this method.
+#if 0
+ // Create single machine cluster. Note that this will create a session and
+ // initialize the gpu devices.
+ const int num_cpu_cores =
+ tensorflow::grappler::GetNumAvailableLogicalCPUCores();
+ const int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
+ VLOG(2) << "cpu_cores: " << num_cpu_cores;
+ VLOG(2) << "gpus: " << num_gpus;
+ const int timeout_s = 60 * 10;
+ std::unique_ptr<tensorflow::grappler::Cluster> cluster(
+ new tensorflow::grappler::SingleMachine(
+ timeout_s, num_cpu_cores, num_gpus));
+ // These settings are the defaults in tensorflow/python/grappler/cluster.py.
+ cluster->DisableDetailedStats(true);
+ cluster->AllowSoftPlacement(true);
+ cluster->SetNumWarmupSteps(10);
+ TF_RETURN_IF_ERROR(cluster->Provision());
+#else
+ // Create virtual cluster. Grappler requires a virtual cluster with a proper
+ // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode.
+ // We add numbers from a Pascal card here to have flops>0.
tensorflow::DeviceProperties device_properties;
device_properties.set_type("GPU");
device_properties.mutable_environment()->insert({"architecture", "6"});
@@ -217,47 +240,43 @@ tensorflow::Status ConvertGraphDefToTensorRT(
std::unique_ptr<tensorflow::grappler::Cluster> cluster(
new tensorflow::grappler::VirtualCluster(
{{"/GPU:0", device_properties}}));
+#endif
- // single machine
- int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
- int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
- VLOG(2) << "cpu_cores: " << num_cpu_cores;
- VLOG(2) << "gpus: " << num_gpus;
+ // Create RewriterConfig.
tensorflow::RewriterConfig rw_cfg;
- // use only const folding and layout for the time being since new optimizers
- // break the graph for us
+ // TODO(aaroey): use only const folding and layout for the time being since
+ // new optimizers break the graph for trt.
rw_cfg.add_optimizers("constfold");
rw_cfg.add_optimizers("layout");
- rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE);
+ auto optimizer = rw_cfg.add_custom_optimizers();
+ optimizer->set_name("TensorRTOptimizer");
+ auto& parameters = *(optimizer->mutable_parameter_map());
+ parameters["minimum_segment_size"].set_i(minimum_segment_size);
+ parameters["max_batch_size"].set_i(max_batch_size);
+ parameters["is_dynamic_op"].set_b(is_dyn_op);
+ parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes);
+ TF_RETURN_IF_ERROR(GetPrecisionModeName(
+ precision_mode, parameters["precision_mode"].mutable_s()));
+ parameters["maximum_cached_engines"].set_i(max_cached_engines);
+ if (!cached_engine_batches.empty()) {
+ auto list = parameters["cached_engine_batches"].mutable_list();
+ for (const int batch : cached_engine_batches) {
+ list->add_i(batch);
+ }
+ }
+
+ // Run optimizer.
tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg);
- tensorflow::GraphDef gdef;
- TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef));
- item.graph = gdef;
-
- // AJ refactoring shape inference through grappler/GraphProperties.
- tensorflow::grappler::GraphProperties static_graph_properties(item);
- TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
- // Build full graph
- ConversionParams cp;
- cp.input_graph_def = &gdef;
- cp.output_names = &output_names;
- cp.max_batch_size = max_batch_size;
- cp.output_graph_def = new_graph_def;
- cp.precision_mode = precision_mode;
- cp.is_dyn_op = is_dyn_op;
- cp.max_cached_engines = max_cached_engines;
- cp.cached_engine_batches = cached_engine_batches;
- cp.minimum_segment_size = minimum_segment_size;
- cp.graph_properties = &static_graph_properties;
- cp.max_workspace_size_bytes = max_workspace_size_bytes;
+ TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def));
+
if (VLOG_IS_ON(5)) {
std::fstream f;
f.open("TRTConversionInput.pb",
std::fstream::out | std::fstream::binary | std::fstream::trunc);
- f << gdef.SerializeAsString();
+ f << new_graph_def->SerializeAsString();
f.close();
}
- return ConvertAfterShapes(cp);
+ return Status::OK();
}
// Function to get subsegment information structure.
@@ -268,11 +287,10 @@ tensorflow::Status GetEngineInfo(
const std::unordered_map<string, tensorflow::Node*>& node_map,
const std::vector<tensorflow::Node*>& reverse_topo_order,
EngineInfo* info) {
- std::vector<int> subgraph_node_ids;
+ std::vector<int> subgraph_node_ids; // Topologically sorted node ids.
+ std::set<string> subgraph_node_names = segment_nodes;
std::set<int> added_const_node_ids; // Used to prevent double insertion.
std::set<string> segment_devices;
- int input_port = 0;
- int output_port = 0;
// Map from src_node_name+port to the unique port numbers of the TRT op, where
// the src_node_name is the name of the source node of the input/output
@@ -280,13 +298,12 @@ tensorflow::Status GetEngineInfo(
// input/output edges must be in different split of the graph.
// TODO(aaroey): consider using node id and port instead.
// TODO(aaroey): using topo order instead of reverting reverse topo order.
- std::unordered_map<string, int> created_edges;
+ std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
++it) {
const auto& node_name = (*it)->name();
-
if (segment_nodes.count(node_name) == 0) continue;
- auto node = node_map.at(node_name);
+ auto node = *it;
auto node_device = node->requested_device();
if (!node_device.empty()) {
segment_devices.insert(node_device);
@@ -299,64 +316,93 @@ tensorflow::Status GetEngineInfo(
}
}
const int node_id = node->id();
+ subgraph_node_ids.push_back(node_id);
+ // Create input connections.
for (const auto edge : node->in_edges()) {
auto input_node = edge->src();
- if (segment_nodes.count(input_node->name()) == 0 &&
- !edge->IsControlEdge() && !input_node->IsSource()) {
- // Add constant input node into the segment. We don't care if it has
- // other output edges going into other engines or TF nodes. Since we add
- // it only to the subsegment node list, not the subsegment itself, it
- // won't be removed from the graph. If it doesn't have any edges, TF
- // will prune it out.
- if (input_node->type_string() == "Const") {
- if (added_const_node_ids.count(input_node->id()) == 0) {
- added_const_node_ids.insert(input_node->id());
- subgraph_node_ids.push_back(input_node->id());
- }
+ if (input_node->IsSource() || segment_nodes.count(input_node->name())) {
+ continue;
+ }
+ if (edge->IsControlEdge()) {
+ // Control input.
+ info->connections.emplace_back(input_node->name(), input_node->id(),
+ node_name, node_id,
+ /*input_edge=*/true);
+ } else if (input_node->type_string() == "Const") {
+ // Add constant data input nodes into the segment graphdef (thus also in
+ // the engine). We don't care if it has other output edges going into
+ // other engines or TF nodes. Since we add it only to the segment
+ // graphdef, not the segment itself, it won't be removed from the graph.
+ // If it doesn't have any edges, TF will prune it out.
+ //
+ // Note that the segmenter already ensures that the constant data input
+ // is valid and supported by the engine.
+ if (!added_const_node_ids.insert(input_node->id()).second) {
+ // Already added before.
+ continue;
+ }
+ VLOG(1) << "Adding const node " << input_node->name();
+ QCHECK(subgraph_node_names.insert(input_node->name()).second);
+ // Since we already add (duplicate) the const input node to the segment
+ // graphdef, it is no longer a data dependency; to keep the dependency
+ // correct we still add a control dependency.
+ info->connections.emplace_back(input_node->name(), input_node->id(),
+ node_name, node_id,
+ /*input_edge=*/true);
+ } else {
+ // Non-const data input.
+ int port = Graph::kControlSlot - 1;
+ // Use the source non-segment node name/port as key.
+ const string s = StrCat(input_node->name(), ":", edge->src_output());
+ VLOG(1) << "Input edge = " << s;
+ if (input_to_engine_port.count(s)) {
+ port = input_to_engine_port.at(s);
} else {
- string s(input_node->name());
- StrAppend(&s, ":", edge->src_output());
- VLOG(1) << "Input edge = " << s;
- int port = input_port;
- if (created_edges.count(s)) {
- port = created_edges.at(s);
- } else {
- created_edges.insert({s, port});
- input_port++;
- }
- info->connections.emplace_back(input_node->name(), input_node->id(),
- edge->src_output(), node_name, node_id,
- edge->dst_input(), true, port);
+ port = input_to_engine_port.size();
+ input_to_engine_port.insert({s, port});
}
+ info->connections.emplace_back(
+ input_node->name(), input_node->id(), edge->src_output(), node_name,
+ node_id, edge->dst_input(), /*input_edge=*/true, port);
}
}
- // We need to add possible const input nodes before adding this node in
- // order to keep the topological order.
- subgraph_node_ids.push_back(node_id);
+ // Create output connections.
for (const auto edge : node->out_edges()) {
auto output_node = edge->dst();
- if (segment_nodes.count(output_node->name()) == 0 &&
- !edge->IsControlEdge() && !output_node->IsSink()) {
- string s(node_name);
- StrAppend(&s, ":", edge->src_output());
+ if (output_node->IsSink() || segment_nodes.count(output_node->name())) {
+ continue;
+ }
+ if (edge->IsControlEdge()) {
+ // Control output.
+ info->connections.emplace_back(output_node->name(), output_node->id(),
+ node_name, node_id,
+ /*input_edge=*/false);
+ } else {
+ // Data output.
+ int port = Graph::kControlSlot - 1;
+ // Use the source segment node name/port as key.
+ const string s = StrCat(node_name, ":", edge->src_output());
VLOG(1) << "Output edge = " << s;
- int port = output_port;
- if (created_edges.count(s)) {
- port = created_edges.at(s);
+ if (output_to_engine_port.count(s)) {
+ port = output_to_engine_port.at(s);
} else {
- created_edges.insert({s, port});
- output_port++;
+ port = output_to_engine_port.size();
+ output_to_engine_port.insert({s, port});
}
- info->connections.emplace_back(output_node->name(), output_node->id(),
- edge->dst_input(), node_name, node_id,
- edge->src_output(), false, port);
+ info->connections.emplace_back(
+ output_node->name(), output_node->id(), edge->dst_input(),
+ node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
}
}
- }
+ } // For each segment node in topological order.
+ // Construct the const nodes first.
+ subgraph_node_ids.insert(subgraph_node_ids.begin(),
+ added_const_node_ids.begin(),
+ added_const_node_ids.end());
TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
- g, graph_properties, subgraph_node_ids, &info->connections,
- &info->segment_graph_def, &info->engine_name));
+ g, graph_properties, subgraph_node_names, subgraph_node_ids,
+ &info->connections, &info->segment_graph_def, &info->engine_name));
// TODO(sami): This should not happen once segmenter is updated.
if (segment_devices.size() == 1) {
info->device = *segment_devices.begin();
@@ -366,94 +412,137 @@ tensorflow::Status GetEngineInfo(
<< "but this shouldn't have happened";
info->device = *segment_devices.begin();
} else {
- VLOG(1) << "Segment devices size is 0";
+ LOG(ERROR) << "Can't find a device placement for the op!";
}
return Status::OK();
}
-// Function to insert a TRT node into the graph. The graph is not modified if
-// the returned status is not ok.
-// 'alloc' is only used for creating static engine.
-tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
- const std::vector<EngineInfo>& infos, int pos,
+// Helper function to update edge connection from the removed node to the
+// engine node. If an outside node is gone, it must have been absorbed into
+// an engine node. Find the engine node.
+void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
+ const size_t my_engine_id,
+ const std::vector<Node*>& engine_nodes,
+ const bool is_input_edge, const string& node_name,
+ tensorflow::Node** node, int* port) {
+ for (size_t t = 0; t < infos.size(); ++t) {
+ if (t == my_engine_id) {
+ continue;
+ }
+ const auto& info = infos.at(t);
+ for (const auto& eng_conn : info.connections) {
+ // If the connection being updated is an input connection, the source of
+ // the connection must be an output connection of another engine, and vice
+ // versa.
+ if (is_input_edge == eng_conn.is_input_edge) continue;
+ if (eng_conn.inside_node_name == node_name &&
+ eng_conn.inside_port == *port) {
+ *node = CHECK_NOTNULL(engine_nodes[t]);
+ QCHECK_EQ(info.engine_name, (**node).name())
+ << "Engine name mismatch: " << info.engine_name << " vs "
+ << (**node).name();
+ *port = eng_conn.port_number;
+ return;
+ }
+ }
+ }
+ LOG(FATAL) << "Node " << (**node).name() << " not found in any engine.";
+}
+
+// Function to insert a TRT engine node into the graph.
+// Create engine nodes in the following way:
+// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos]
+// 2. When an engine node is created, add it into the graph with necessary
+// re-wiring.
+// 2.1. If the outside connected node still exists, connect the engine
+// node to it.
+// 2.2. If the outside connected node is gone, it must have been absorbed
+// into another engine node (which was processed before the current
+// one). Connect to the pre-existing engine node instead.
+// 3. In this way, we ensure the graph is topologically sort-able after each
+// invocation of CreateTRTNode().
+tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
+ int max_batch_size, tensorflow::Graph* graph,
nvinfer1::IGpuAllocator* alloc,
- int max_batch_size) {
+ std::vector<Node*>* engine_nodes) {
const auto& info = infos.at(pos);
+ TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail");
std::vector<tensorflow::TensorShapeProto> output_shape_protos;
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
std::vector<tensorflow::PartialTensorShape> input_shapes;
std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
+ std::vector<tensorflow::Node*> input_nodes;
+ std::vector<tensorflow::Node*> control_input_nodes;
+ std::unordered_set<string> control_input_names;
std::vector<tensorflow::DataType> out_types;
- VLOG(1) << "Processing " << info.engine_name;
- // Update the shape and data types of input/output nodes, and find all unique
- // inputs.
+ VLOG(1) << "Processing " << info.engine_name;
+ // Collect needed info for creating the engine node in the graph
for (const auto& conn : info.connections) {
- if (!conn.is_input_edge) {
- // Set the shapes and data types of output edge.
- tensorflow::TensorShapeProto out_shape;
- // shape of the output node inside segment
- conn.inside_shape.AsProto(&out_shape);
- if (output_shape_protos.size() <= conn.port_number) {
- output_shape_protos.resize(conn.port_number + 1);
- out_types.resize(conn.port_number + 1);
+ // Control edges
+ if (conn.is_control_edge()) {
+ // Skip control outputs for now. Control output info is not needed for
+ // node creation and will be processed later.
+ if (!conn.is_input_edge) continue;
+
+ // Rewire the control input if it's not found in the original graph.
+ tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
+ int port = tensorflow::Graph::kControlSlot;
+ if (!input_node) {
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+ conn.outside_node_name, &input_node, &port);
+ QCHECK_EQ(Graph::kControlSlot, port);
}
- output_shape_protos.at(conn.port_number) = out_shape;
- out_types.at(conn.port_number) = conn.connection_type;
- continue;
- }
-
- // Set the shapes and data types of input edge.
- tensorflow::TensorShapeProto in_shape;
- conn.outside_shape.AsProto(&in_shape);
- if (input_shape_protos.size() <= conn.port_number) {
- input_shape_protos.resize(conn.port_number + 1);
- input_shapes.resize(conn.port_number + 1);
- }
- input_shape_protos.at(conn.port_number) = in_shape;
- input_shapes.at(conn.port_number) = conn.outside_shape;
-
- string input_node = conn.outside_node_name;
- int input_port = conn.outside_port;
- bool found_engine = false;
- // Rewire the inputs to other engines if they contain original input node.
- // Note that we use the information of the engine here, not the information
- // of the created TRT nodes, so we're able to find all the connections to
- // any other engines beforehand.
- for (size_t t = 0; t < infos.size(); ++t) {
- if (t == pos) continue;
- auto& engine_info = infos.at(t);
- for (const auto& eng_conn : engine_info.connections) {
- if (eng_conn.is_input_edge) continue;
- if (eng_conn.inside_node_name == input_node) {
- input_node = engine_info.engine_name;
- if (eng_conn.inside_port == input_port) {
- input_port = eng_conn.port_number;
- found_engine = true;
- break;
- }
- }
+ if (!control_input_names.insert(input_node->name()).second) {
+ continue;
}
- if (found_engine) break;
- }
- VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> "
- << info.engine_name << ":" << inputs.size();
- // Skip duplicate inputs.
- // TODO(aaroey): use std::find instead. GetEngineInfo already remove
- // duplicate connections, so here we should never find any duplicate?
- bool new_input = true;
- for (const auto& inp : inputs) {
- if (inp.node == input_node && inp.index == input_port) {
- new_input = false;
- break;
+ control_input_nodes.push_back(input_node);
+ VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
+ << info.engine_name;
+ } else {
+ // Data edges
+ if (!conn.is_input_edge) {
+ // Set the shapes and data types of output edge.
+ tensorflow::TensorShapeProto out_shape;
+ // shape of the output node inside segment
+ conn.inside_shape.AsProto(&out_shape);
+ if (output_shape_protos.size() <= conn.port_number) {
+ output_shape_protos.resize(conn.port_number + 1);
+ out_types.resize(conn.port_number + 1);
+ }
+ output_shape_protos.at(conn.port_number) = out_shape;
+ out_types.at(conn.port_number) = conn.connection_type;
+ } else {
+ // Set the shapes and data types of input edge.
+ tensorflow::TensorShapeProto in_shape;
+ conn.outside_shape.AsProto(&in_shape);
+ if (input_shape_protos.size() <= conn.port_number) {
+ input_shape_protos.resize(conn.port_number + 1);
+ input_shapes.resize(conn.port_number + 1);
+ }
+ input_shape_protos.at(conn.port_number) = in_shape;
+ input_shapes.at(conn.port_number) = conn.outside_shape;
+
+ // Rewire the data input if it's not found in the original graph.
+ tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id);
+ int port = conn.outside_port;
+ if (!input_node) {
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
+ conn.outside_node_name, &input_node, &port);
+ }
+ if (std::find_if(
+ std::begin(inputs), std::end(inputs),
+ [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
+ return inp.node == input_node->name() && inp.index == port;
+ }) == std::end(inputs)) {
+ inputs.emplace_back(input_node->name(), port, conn.connection_type);
+ input_nodes.push_back(CHECK_NOTNULL(input_node));
+ VLOG(1) << "Engine Input " << input_node->name() << ":" << port
+ << " -> " << info.engine_name << ":" << inputs.size() - 1;
+ }
}
}
- if (new_input) {
- inputs.emplace_back(input_node, input_port, conn.connection_type);
- }
}
-
- // Build the engine and get its serialized representation.
string segment_string;
if (info.engine_type == EngineInfo::EngineType::TRTStatic ||
info.precision_mode == INT8MODE) {
@@ -485,21 +574,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// TODO(aaroey): use enum instead, and add a helper method to do the
// conversion.
string prec_string;
- switch (info.precision_mode) {
- case FP32MODE:
- prec_string = "FP32";
- break;
- case FP16MODE:
- prec_string = "FP16";
- break;
- case INT8MODE:
- prec_string = "INT8";
- if (!TRTResourceManager::instance()->getManager("TRTCalibration")) {
- LOG(ERROR) << "Failed to construct calibration storage";
- }
- break;
- default:
- return tensorflow::errors::OutOfRange("Unknown precision mode");
+ TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string));
+ if (info.precision_mode == INT8MODE &&
+ !TRTResourceManager::instance()->getManager("TRTCalibration")) {
+ LOG(ERROR) << "Failed to construct calibration storage";
}
tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
if (!info.device.empty()) node_builder.Device(info.device);
@@ -511,6 +589,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
VLOG(1) << ins;
}
node_builder.Input(inputs);
+ for (const string& c : control_input_names) {
+ node_builder.ControlInput(c);
+ }
+
if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
info.cached_engine_batches.size()) {
LOG(WARNING) << "Cached engine batches are ignored for static engines";
@@ -539,34 +621,55 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
// Up until this point, graph is not modified. If we return !status.ok() from
// here, this segment will be skipped
+ // TODO(aaroey): let it return proper error status for the following logic
+ // instead of check-failing.
tensorflow::Node* engine_node = graph->AddNode(trt_node, &status);
+ (*engine_nodes)[pos] = engine_node;
if (!status.ok()) {
LOG(ERROR) << "Adding node failed " << status;
return status;
}
+ // Add control input and input edges to the engine node.
+ for (const auto in : control_input_nodes) {
+ VLOG(1) << "Connecting control edge from " << in->name() << " to "
+ << engine_node->name();
+ graph->AddControlEdge(in, engine_node);
+ }
+ VLOG(1) << "input_nodes size = " << input_nodes.size();
+ for (int i = 0; i < input_nodes.size(); ++i) {
+ Node* n = CHECK_NOTNULL(input_nodes[i]);
+ const auto& in = inputs[i];
+ VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
+ << " to " << engine_node->name() << ":" << i;
+ graph->AddEdge(n, in.index, engine_node, i);
+ }
+
// Updates the inputs of output edges destination nodes, and point them to the
// engine node.
for (auto& conn : info.connections) {
- if (conn.is_input_edge) continue;
- VLOG(1) << " Updating DBG " << engine_node->name() << " out_port "
- << conn.port_number << " out_id " << conn.outside_id
- << " name=" << conn.outside_node_name;
- auto dst_node = graph->FindNodeId(conn.outside_id);
- // dst_node can only be removed if it is an input node of another engine.
- // In this case, other engines input edge is updated in nodedef to point to
- // this engine. Even though edge doesn't exists in the graph, when it is
- // deserialized again, correct edges will be constructed. This is a problem
- // of graph->AddNode().
- if (!dst_node) continue;
+ if (conn.is_input_edge) {
+ continue;
+ }
+ tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id);
+ int port = conn.outside_port;
+ if (!output_node) {
+ UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
+ conn.outside_node_name, &output_node, &port);
+ }
VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
- << " to " << dst_node->name() << ":" << conn.outside_port;
- auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node,
- conn.outside_port);
- CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":"
- << conn.port_number << " -> " << dst_node->name() << ":"
- << conn.outside_port;
+ << " to " << output_node->name() << ":" << port;
+ if (conn.is_control_edge()) {
+ QCHECK_EQ(Graph::kControlSlot, port);
+ graph->AddControlEdge(engine_node, output_node);
+ } else {
+ auto new_edge =
+ graph->AddEdge(engine_node, conn.port_number, output_node, port);
+ QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name()
+ << ":" << conn.port_number << " -> "
+ << output_node->name() << ":" << conn.outside_port;
+ }
}
- return status;
+ return Status::OK();
}
// Function to construct a funcdef from the segment and add it to the graph.
@@ -666,72 +769,36 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
}
std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
- ConversionParams& params, EngineInfo& engine) {
+ const ConversionParams& params, const EngineInfo& engine) {
int cuda_device_id = -1;
- auto check_device_id = [](int tfid) -> int {
- tensorflow::TfGpuId tf_gpu_id(tfid);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
- if (s.ok()) {
- VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
- << cuda_gpu_id.value();
- return cuda_gpu_id.value();
- }
- VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s;
- return -1;
- };
tensorflow::Allocator* dev_allocator = nullptr;
- // we need to us PM here since in python path there is no way to get
- // to allocators.
- // TODO(sami): when grappler devices become available else path will not be
- // necessary
- auto pm = tensorflow::GPUProcessState::singleton();
- if (params.cluster) { // get allocator
- tensorflow::Device* device = nullptr;
- if (params.cluster->GetDeviceSet()) {
- device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device);
+ if (params.cluster) {
+ std::vector<tensorflow::Device*> devices;
+ if (!engine.device.empty() && params.cluster->GetDeviceSet()) {
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
+ parsed_name.has_id) {
+ params.cluster->GetDeviceSet()->FindMatchingDevices(parsed_name,
+ &devices);
+ }
}
- if (device) {
+ if (!devices.empty()) {
+ if (devices.size() > 1) {
+ string msg = "Found multiple matching devices using name '";
+ StrAppend(&msg, engine.device, "': ");
+ for (auto d : devices) StrAppend(&msg, d->name(), ", ");
+ StrAppend(&msg, ". Will get the allocator from first one.");
+ LOG(WARNING) << msg;
+ }
tensorflow::AllocatorAttributes alloc_attr;
- dev_allocator = device->GetAllocator(alloc_attr);
- VLOG(1) << "Using allocator " << dev_allocator->Name();
+ cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id;
+ dev_allocator = devices[0]->GetAllocator(alloc_attr);
+ VLOG(1) << "Using allocator " << dev_allocator->Name()
+ << " and cuda_device_id " << cuda_device_id;
} else {
LOG(WARNING) << "Cluster is set but device '" << engine.device
<< "' is not found in the cluster";
}
- } else { // cluster not found, possibly a python call
- VLOG(1) << "Cluster is not set, probably called from python";
- int found_device = 0;
- bool try_gpu_ids = true;
- // if device is set, try to find the device. Might be a problem for multi
- // host case but TensorRT do not support multi host setups yet.
- if (!engine.device.empty()) {
- DeviceNameUtils::ParsedName parsed_name;
- if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) {
- cuda_device_id = parsed_name.has_id ? parsed_name.id : -1;
- }
- try_gpu_ids = !parsed_name.has_id;
- }
- if (try_gpu_ids) {
- while (found_device < 100) {
- cuda_device_id = check_device_id(found_device);
- if (cuda_device_id >= 0) break;
- found_device++;
- }
- }
- if (found_device == 100) {
- LOG(ERROR) << " Can't find a GPU device to work with. Please "
- "instantiate a session to initialize devices";
- return std::make_pair(cuda_device_id, dev_allocator);
- }
- LOG(WARNING)
- << "Can't determine the device, constructing an allocator at device "
- << found_device;
- tensorflow::GPUOptions gpuoptions;
- // this will be a noop if device is already initialized
- gpuoptions.set_allow_growth(true);
- tensorflow::TfGpuId tf_gpu_id(found_device);
- dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1);
}
return std::make_pair(cuda_device_id, dev_allocator);
}
@@ -824,6 +891,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
}
VLOG(1) << "Current cuda device is " << old_cuda_device;
+ std::vector<Node*> engine_nodes;
+ engine_nodes.resize(engine_segments.size());
for (int i = 0; i < engine_segments.size(); ++i) {
auto& engine = engine_segments.at(i);
// Partition the workspace size by the average of node ratio and segment
@@ -847,19 +916,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
}
cudaSetDevice(cuda_device_id);
- auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(),
- params.max_batch_size);
+ auto status = CreateTRTNode(engine_segments, i, params.max_batch_size,
+ &graph, alloc.get(), &engine_nodes);
// If status is ok, we successfully added the node to the graph and can
// remove segment ops. Otherwise graph is not modified.
+ const string msg = StrCat("Engine ", engine.engine_name,
+ " creation for segment ", i, ", composed of ",
+ converted_segments.at(i).first.size(), " nodes");
if (status.ok()) {
+ LOG(INFO) << msg << " succeeded.";
for (auto node_name : converted_segments.at(i).first) {
graph.RemoveNode(node_map.at(node_name));
}
} else {
// Graph is not modified.
- LOG(WARNING) << "Engine creation for segment " << i << ", composed of "
- << converted_segments.at(i).first.size()
- << " nodes failed: " << status << ". Skipping...";
+ LOG(WARNING) << msg << " failed: " << status << ". Skipping...";
}
}
cudaSetDevice(old_cuda_device);
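To illustrate the refactor above (not part of the patch): `ConvertGraphDefToTensorRT` now simply runs the grappler MetaOptimizer with a `TensorRTOptimizer` custom optimizer, so the same conversion can be described from Python as a `RewriterConfig`. A hedged sketch with example parameter values:

```python
# Hedged sketch: a RewriterConfig mirroring the parameter map that the C++
# conversion entry point fills in for the TensorRTOptimizer. Values are
# examples only.
from tensorflow.core.protobuf import rewriter_config_pb2

rewriter_cfg = rewriter_config_pb2.RewriterConfig()
rewriter_cfg.optimizers.extend(['constfold', 'layout'])
trt_optimizer = rewriter_cfg.custom_optimizers.add()
trt_optimizer.name = 'TensorRTOptimizer'
params = trt_optimizer.parameter_map
params['minimum_segment_size'].i = 3
params['max_batch_size'].i = 1
params['is_dynamic_op'].b = False
params['max_workspace_size_bytes'].i = 1 << 30
params['precision_mode'].s = b'FP32'
params['maximum_cached_engines'].i = 1
params['cached_engine_batches'].list.i.extend([1, 8])
print(rewriter_cfg)
```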
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 451d6fe698..35fa590254 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -2690,7 +2691,7 @@ tensorflow::Status ConvertGraphDefToEngine(
// Graph nodes are already topologically sorted during construction
for (const auto& node_def : gdef.node()) {
string node_name = node_def.name();
- VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op();
+ VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
(node_def.op() == "Placeholder")) {
nvinfer1::DimsCHW input_dim_pseudo_chw;
@@ -2788,6 +2789,7 @@ tensorflow::Status ConvertGraphDefToEngine(
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::set<string>& subgraph_node_names,
const std::vector<int>& subgraph_node_ids, // In topological order
std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope) {
@@ -2796,6 +2798,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
// nodes in the segment graphdef.
for (size_t i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
+ if (connection.is_control_edge()) continue;
auto outside_node = graph->FindNodeId(connection.outside_id);
if (!outside_node) {
// This should never happen, unless the original graph is problematic.
@@ -2809,13 +2812,13 @@ tensorflow::Status ConvertSegmentToGraphDef(
GetInputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);
-
+ connection.outside_shape = partial_shape;
} else {
GetOutputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);
+ connection.inside_shape = partial_shape;
}
- connection.outside_shape = partial_shape;
connection.connection_type = dtype;
// Add dummy input/output nodes to the segment graphdef.
@@ -2868,12 +2871,12 @@ tensorflow::Status ConvertSegmentToGraphDef(
old_to_new_id_map[node_id] = segment_def->node_size();
auto snode = segment_def->add_node();
snode->CopyFrom(node->def());
- VLOG(1) << "Copying " << snode->name() << " to subgraph";
+ VLOG(2) << "Copying " << snode->name() << " to subgraph";
}
// Update the inputs of the new input nodes to point to placeholder nodes.
for (int i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
- if (!connection.is_input_edge) continue;
+ if (connection.is_control_edge() || !connection.is_input_edge) continue;
auto snode =
segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
const string placeholder_name =
@@ -2883,6 +2886,39 @@ tensorflow::Status ConvertSegmentToGraphDef(
<< placeholder_name;
snode->set_input(connection.inside_port, placeholder_name);
}
+ // Remove control inputs that are not inside the segment.
+ for (int i = 0; i < segment_def->node_size(); ++i) {
+ auto snode = segment_def->mutable_node(i);
+ const int input_size = snode->input_size();
+ int input_idx = 0;
+ int actual_input_idx = 0;
+ while (input_idx < input_size) {
+ TensorId input = ParseTensorName(snode->input(input_idx));
+ if (!subgraph_node_names.count(
+ string(input.first.data(), input.first.size())) &&
+ !str_util::StartsWith(input.first, kInputPHName)) {
+ if (input.second == Graph::kControlSlot) {
+ VLOG(1) << "... removing control inputs " << input.first
+ << " from subgraph.";
+ ++input_idx;
+ continue;
+ } else {
+ return tensorflow::errors::InvalidArgument(
+ "Found non control input outside the segment that is not an "
+ "engine connection to ",
+ snode->name(), ": ", input.first);
+ }
+ }
+ if (actual_input_idx != input_idx) {
+ snode->set_input(actual_input_idx, snode->input(input_idx));
+ }
+ ++input_idx;
+ ++actual_input_idx;
+ }
+ for (int remove = input_size - actual_input_idx; remove > 0; --remove) {
+ snode->mutable_input()->RemoveLast();
+ }
+ }
*common_scope = local_scope;
VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
return tensorflow::Status::OK();
@@ -2897,12 +2933,12 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
nvinfer1::DataType trt_dtype;
Status status = ValidateInputProperties(shape, dtype, &trt_dtype);
if (!status.ok()) {
- VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
<< ": " << status;
return false;
}
if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
- VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
+ VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name()
<< " which has an input at port " << in_edge->dst_input()
<< " with #dim<3 and is not a const: " << shape;
return false;
@@ -2913,7 +2949,7 @@ bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const {
if (out_edge->IsControlEdge()) return true;
if (out_edge->src()->type_string() == "Const") {
- VLOG(2) << "--> Need to remove output node " << out_edge->src()->name()
+ VLOG(1) << "--> Need to remove output node " << out_edge->src()->name()
<< " which is a Const.";
return false;
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 6ae60ec352..a60253740f 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -36,16 +36,12 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-static const char* kInputPHName = "InputPH_";
-static const char* kOutputPHName = "OutputPH_";
+static const char* kInputPHName = "TensorRTInputPH_";
+static const char* kOutputPHName = "TensorRTOutputPH_";
namespace convert {
-// TODO(aaroey): use an enum instead.
-const int FP32MODE = 0;
-const int FP16MODE = 1;
-const int INT8MODE = 2;
-
struct EngineConnection {
+ // Constructs a non-control edge.
EngineConnection(const string& outside, int out_id, int out_port,
const string& inside, int in_id, int in_port,
bool input_edge, int port)
@@ -58,21 +54,35 @@ struct EngineConnection {
is_input_edge(input_edge),
port_number(port) {}
+ // Constructs a control edge.
+ EngineConnection(const string& outside, int out_id, const string& inside,
+ int in_id, bool input_edge)
+ : outside_node_name(outside),
+ outside_id(out_id),
+ outside_port(Graph::kControlSlot),
+ inside_node_name(inside),
+ inside_id(in_id),
+ inside_port(Graph::kControlSlot),
+ is_input_edge(input_edge),
+ port_number(Graph::kControlSlot) {}
+
+ bool is_control_edge() const { return port_number == Graph::kControlSlot; }
+
const string outside_node_name;
const int outside_id;
const int outside_port;
- tensorflow::PartialTensorShape outside_shape;
+ tensorflow::PartialTensorShape outside_shape; // Only set for input edge.
const string inside_node_name;
const int inside_id;
const int inside_port;
- tensorflow::PartialTensorShape inside_shape;
+ tensorflow::PartialTensorShape inside_shape; // Only set for output edge.
tensorflow::DataType connection_type;
- bool is_input_edge;
+ const bool is_input_edge;
- // The port number of the TRT node connecting to this edge.
- int port_number;
+ // The port number of the TRT node connected with this edge.
+ const int port_number;
};
struct EngineInfo {
@@ -85,7 +95,9 @@ struct EngineInfo {
string device;
tensorflow::GraphDef segment_graph_def;
- // The segment nodes that are on one side of the edges are topological sorted.
+ // Non-control input connections inside this vector are sorted in a way such
+ // that the segment nodes connecting to them are topologically sorted.
+ // In addition, for non-control connections, there must be no duplicates.
std::vector<EngineConnection> connections;
enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
@@ -101,6 +113,7 @@ struct EngineInfo {
// (OutputPH_*). This function needs to be called before TensorRT nodes
// inserted in order to correctly get sizes from the original graph.
//
+// - subgraph_node_names: the node names of the subgraph.
// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
// topological order.
// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
@@ -110,6 +123,7 @@ struct EngineInfo {
tensorflow::Status ConvertSegmentToGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::set<string>& subgraph_node_names,
const std::vector<int>& subgraph_node_ids,
std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope);
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index 044c736c03..f33f2cc4d6 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stacktrace.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@@ -189,9 +190,6 @@ tensorflow::Status TRTOptimizationPass::Optimize(
tensorflow::grappler::Cluster* cluster,
const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) {
VLOG(1) << "Called TRTOptimization Pass " << name_;
- if (VLOG_IS_ON(1)) {
- PrintDebugInfo(cluster, item);
- }
// This is a hack to workaround optimizer issue. MetaOptimizer calls
// optimization passes on function objects as well, we should not modify
// generated funcdefs! This is fragile but we don't have any other option
@@ -203,6 +201,10 @@ tensorflow::Status TRTOptimizationPass::Optimize(
*optimized_graph = item.graph;
return tensorflow::Status::OK();
}
+ if (VLOG_IS_ON(1)) {
+ VLOG(2) << CurrentStackTrace();
+ PrintDebugInfo(cluster, item);
+ }
int max_dim = -1;
if (item.feed.size()) {
for (const auto& f : item.feed) {
diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc
index 17857cf4d0..e7a1febb8c 100644
--- a/tensorflow/contrib/tensorrt/convert/utils.cc
+++ b/tensorflow/contrib/tensorrt/convert/utils.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
namespace tensorflow {
namespace tensorrt {
@@ -31,5 +34,36 @@ bool IsGoogleTensorRTEnabled() {
#endif
}
+Status GetPrecisionModeName(const int precision_mode, string* name) {
+ switch (precision_mode) {
+ case FP32MODE:
+ *name = "FP32";
+ break;
+ case FP16MODE:
+ *name = "FP16";
+ break;
+ case INT8MODE:
+ *name = "INT8";
+ break;
+ default:
+ return tensorflow::errors::OutOfRange("Unknown precision mode");
+ }
+ return Status::OK();
+}
+
+Status GetPrecisionMode(const string& name, int* precision_mode) {
+ if (name == "FP32") {
+ *precision_mode = FP32MODE;
+ } else if (name == "FP16") {
+ *precision_mode = FP16MODE;
+ } else if (name == "INT8") {
+ *precision_mode = INT8MODE;
+ } else {
+ return tensorflow::errors::InvalidArgument("Invalid precision mode name: ",
+ name);
+ }
+ return Status::OK();
+}
+
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h
index 8b5f4d614a..0592f31462 100644
--- a/tensorflow/contrib/tensorrt/convert/utils.h
+++ b/tensorflow/contrib/tensorrt/convert/utils.h
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
+#include "tensorflow/core/lib/core/status.h"
+
namespace tensorflow {
namespace tensorrt {
@@ -33,6 +35,15 @@ using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
bool IsGoogleTensorRTEnabled();
+// TODO(aaroey): use an enum instead.
+const int FP32MODE = 0;
+const int FP16MODE = 1;
+const int INT8MODE = 2;
+
+Status GetPrecisionModeName(const int precision_mode, string* name);
+
+Status GetPrecisionMode(const string& name, int* precision_mode);
+
} // namespace tensorrt
} // namespace tensorflow
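As an illustration of the precision-mode helpers above (not part of the patch): `FP32MODE`, `FP16MODE` and `INT8MODE` map to the names "FP32", "FP16" and "INT8". A hedged Python mirror of that mapping, for illustration only:

```python
# Hedged sketch: a Python mirror of GetPrecisionMode()/GetPrecisionModeName()
# above, for illustration only (the real helpers live in C++).
FP32MODE, FP16MODE, INT8MODE = 0, 1, 2
_NAME_TO_MODE = {'FP32': FP32MODE, 'FP16': FP16MODE, 'INT8': INT8MODE}
_MODE_TO_NAME = {v: k for k, v in _NAME_TO_MODE.items()}

def get_precision_mode(name):
  if name not in _NAME_TO_MODE:
    raise ValueError('Invalid precision mode name: %s' % name)
  return _NAME_TO_MODE[name]

print(get_precision_mode('INT8'), _MODE_TO_NAME[FP16MODE])  # -> 2 FP16
```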
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 6699b71d28..2b42d81f47 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -122,15 +123,9 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
context->GetAttr("calibration_data", &calibration_data));
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
- if (precision_string == "FP32") {
- precision_mode_ = convert::FP32MODE;
- } else if (precision_string == "FP16") {
- precision_mode_ = convert::FP16MODE;
- } else if (precision_string == "INT8") {
- precision_mode_ = convert::INT8MODE;
- }
+ OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_));
calibration_mode_ =
- (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0);
+ (precision_mode_ == INT8MODE && calibration_data.size() == 0);
if (calibration_data.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
calibration_data.resize(0);
@@ -179,7 +174,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment " << name();
lib->Run(opts, native_func_, inputs, outputs,
- [ctx, outputs, helper](const tensorflow::Status& s) {
+ [this, ctx, outputs, helper](const tensorflow::Status& s) {
tensorflow::core::ScopedUnref sc(helper);
VLOG(1) << "Native Segment completed";
if (!s.ok()) {
@@ -189,6 +184,8 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
for (size_t t = 0; t < outputs->size(); ++t) {
ctx->set_output(t, outputs->at(t));
}
+ test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
+ "done");
delete outputs;
});
}
@@ -234,6 +231,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
->implementation()
->GpuStreamMemberHack()));
calib_res->calibrator_->setBatch(input_data, *stream);
+ test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
}
@@ -258,7 +256,7 @@ int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) {
StrCat("Engine buffer is full. buffer limit=", max_cached_engines_,
", current entries=");
for (auto i : cached_engine_batches_) StrAppend(&msg, i, ",");
- StrAppend(&msg, "Requested batch=", num_batch);
+ StrAppend(&msg, " requested batch=", num_batch);
LOG(WARNING) << msg;
return -1;
}
@@ -276,7 +274,8 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
}
const int smallest_engine = GetEngineBatch(ctx);
if (smallest_engine < 0) {
- LOG(WARNING) << "Failed to get engine batch, running native segment";
+ LOG(WARNING) << "Failed to get engine batch, running native segment for "
+ << name();
ExecuteNativeSegment(ctx, helper);
return;
}
@@ -286,14 +285,15 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
- << " failed. Running native segment";
+ << " failed. Running native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(),
engine_ctx_pair.second.get());
if (retry) {
- LOG(WARNING) << "Failed to execute engine, retrying with native segment";
+ LOG(WARNING) << "Failed to execute engine, "
+ << "retrying with native segment for " << name();
ExecuteNativeSegment(ctx, helper);
return;
}
@@ -412,6 +412,7 @@ bool TRTEngineOp::ExecuteTrtEngine(
LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
return kRetry;
}
+ test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
// Synchronization will be done by TF.
return !kRetry;
}
@@ -589,7 +590,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
// TODO(aaroey): maybe setting the max batch size using the python
// calibration wrapper class.
auto s = convert::ConvertGraphDefToEngine(
- *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(),
+ *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(),
workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
cres->calibrator_.get(), &cres->engine_,
/*convert_successfully=*/nullptr);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 59b744e6d3..8fe0675891 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -35,7 +35,7 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-class TRTInt8Calibrator;
+struct TRTInt8Calibrator;
class TRTCalibrationResource;
class AsyncHelper;
// TODO(Sami): Remove this file?
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
index fe4fa166a1..7cdfe2b1a6 100644
--- a/tensorflow/contrib/tensorrt/python/__init__.py
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -20,7 +20,11 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value
from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values
from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value
+from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value
from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled
# pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 2b67931661..4116f2fe30 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -20,26 +20,26 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
+from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
+from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
-from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
-from tensorflow.python.util import compat
-
+from tensorflow.python.training import saver
# pylint: enable=unused-import,line-too-long
-# TODO(skama): get outputs from session when implemented as c++
-# optimization pass
def create_inference_graph(input_graph_def,
outputs,
max_batch_size=1,
@@ -48,7 +48,7 @@ def create_inference_graph(input_graph_def,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[]):
+ cached_engine_batches=None):
"""Python wrapper for the TRT transformation.
Args:
@@ -87,8 +87,7 @@ def create_inference_graph(input_graph_def,
(".".join([str(x) for x in compiled_version]),
".".join([str(x) for x in loaded_version])) +
". Please make sure that correct version of TensorRT " +
- "is available in the system and added to ldconfig or LD_LIBRARY_PATH"
- )
+ "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
raise RuntimeError("Incompatible TensorRT library version")
for i in zip(loaded_version, compiled_version):
if i[0] != i[1]:
@@ -121,41 +120,42 @@ def create_inference_graph(input_graph_def,
to_bytes = py3bytes
to_string = py3string
- out_names = []
- for i in outputs:
- if isinstance(i, ops.Tensor):
- out_names.append(to_bytes(i.name))
- else:
- out_names.append(to_bytes(i))
-
- input_graph_def_str = input_graph_def.SerializeToString()
-
- # TODO(sami): Fix this when we can return status from C++ library
- # There is a problem with the TF internal library setup that doesn't
- # allow us to return a status object from C++. Thus we return a
- # pair or strings where first one is encoded status and the second
- # one is the transformed graphs protobuf string.
- out = trt_convert(input_graph_def_str, out_names, max_batch_size,
- max_workspace_size_bytes, mode, minimum_segment_size,
- is_dynamic_op, maximum_cached_engines,
- cached_engine_batches)
- status = to_string(out[0])
- output_graph_def_string = out[1]
- del input_graph_def_str # Save some memory
- if len(status) < 2:
- raise _impl.UnknownError(None, None, status)
- if status[:2] != "OK":
- msg = status.split(";")
- if len(msg) == 1:
- raise RuntimeError("Status message is malformed {}".format(status))
- # pylint: disable=protected-access
- raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
- int(msg[0]))
- # pylint: enable=protected-access
- output_graph_def = graph_pb2.GraphDef()
- output_graph_def.ParseFromString(output_graph_def_string)
- del output_graph_def_string # Save some memory
- return output_graph_def
+ # Create MetaGraphDef
+ graph = ops.Graph()
+ with graph.as_default():
+ importer.import_graph_def(input_graph_def, name="")
+ meta_graph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
+ if outputs:
+ output_collection = meta_graph_pb2.CollectionDef()
+ output_list = output_collection.node_list.value
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ output_list.append(to_bytes(i.name))
+ else:
+ output_list.append(to_bytes(i))
+ meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+
+ # Create RewriterConfig.
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ optimizer = rewriter_cfg.custom_optimizers.add()
+ optimizer.name = "TensorRTOptimizer"
+ optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+ optimizer.parameter_map["max_batch_size"].i = max_batch_size
+ optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ optimizer.parameter_map[
+ "max_workspace_size_bytes"].i = max_workspace_size_bytes
+ optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
+ optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+ if cached_engine_batches:
+ if not isinstance(cached_engine_batches, list):
+ raise TypeError("cached_engine_batches should be a list.")
+ optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+ cached_engine_batches)
+
+ return tf_optimizer.OptimizeGraph(
+ rewriter_cfg, meta_graph, graph_id=b"tf_graph")
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
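
The rewritten create_inference_graph above no longer serializes the graph and
calls the removed trt_convert SWIG entry point; instead it wraps the input
GraphDef in a MetaGraphDef, records the requested outputs in the "train_op"
collection, and hands everything to Grappler's TensorRTOptimizer through
tf_optimizer.OptimizeGraph. A minimal caller-side sketch, assuming a frozen
GraphDef named frozen_gdef and an output tensor named "logits" (both are
placeholders for the caller's own graph):

    from tensorflow.contrib.tensorrt.python import trt_convert

    # Hypothetical usage of the rewritten wrapper; frozen_gdef and "logits"
    # stand in for the user's own frozen graph and output name.
    trt_gdef = trt_convert.create_inference_graph(
        input_graph_def=frozen_gdef,
        outputs=["logits"],
        max_batch_size=8,
        max_workspace_size_bytes=1 << 25,
        precision_mode="FP16",
        minimum_segment_size=3,
        is_dynamic_op=False,
        maximum_cached_engines=1,
        cached_engine_batches=None)  # must be a list when provided
    # trt_gdef is the optimized GraphDef; converted segments appear in it
    # as TRTEngineOp nodes.
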
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 008fffc954..b43f1b190f 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -414,10 +414,10 @@ tensorflow::Status SegmentGraph(
}
for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
- VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
+ VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
// 'node' must be a TRT candidate...
if (node_segments[node->id()].Value() == nullptr) {
- VLOG(2) << "... not a TRT candidate";
+ VLOG(3) << "... not a TRT candidate";
continue;
}
// Contract output edges to combine 'node' with output
@@ -426,22 +426,22 @@ tensorflow::Status SegmentGraph(
while (true) {
std::set<const SimpleEdge*> contract_edges;
for (const SimpleEdge* out_edge : node->out_edges()) {
- VLOG(2) << "... out node " << out_edge->dst()->name() << " ( "
+ VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
<< out_edge->dst()->id() << " <- " << node->id() << " )";
if (out_edge->IsControlEdge()) {
- VLOG(2) << "... ... Control Edge, Skipping";
+ VLOG(3) << "... ... Control Edge, Skipping";
continue;
}
// Out node must be TRT candidate...
if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
- VLOG(2) << "... ... not a TRT candidate";
+ VLOG(3) << "... ... not a TRT candidate";
continue;
}
if (CanContractEdge(out_edge, graph)) {
- VLOG(2) << "... ... can contract";
+ VLOG(3) << "... ... can contract";
contract_edges.insert(out_edge);
} else {
- VLOG(2) << "... ... cannot contract, would form cycle";
+ VLOG(3) << "... ... cannot contract, would form cycle";
}
}
if (contract_edges.empty()) {
@@ -454,7 +454,7 @@ tensorflow::Status SegmentGraph(
const SimpleNode* src = contract_edge->src();
const SimpleNode* dst = contract_edge->dst();
- VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " ("
+ VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
<< src->id() << " <- " << dst->id();
node_segments[src->id()].Merge(&node_segments[dst->id()]);
@@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph(
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the segment nodes set.
- std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map;
+ std::map<string, std::set<const tensorflow::Node*>> sg_map;
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the device names that the nodes in the segment are
@@ -558,27 +558,36 @@ tensorflow::Status SegmentGraph(
// then after doing this operation the resulting subgraph will keep the
// same properties 1 and 2.
//
- // For simplicity we use heuristics: for input nodes remove all its
- // input, for output nodes remove all its output. In this way, for common
- // cases the number of removed nodes should be minimum.
+ // For simplicity we use heuristics: for input and const output nodes
+ // remove all their inputs, and for non-const output nodes remove all
+ // their outputs. In this way, for common cases the number of removed
+    // nodes should be minimal.
auto remove_nodes = [&segment_nodes](
bool is_input_nodes,
std::deque<const tensorflow::Node*>* que) {
// Run a BFS on the queue to find all the input/output nodes.
std::set<const tensorflow::Node*> visited;
+ std::set<const tensorflow::Node*> logged(que->begin(), que->end());
while (!que->empty()) {
auto node = que->front();
que->pop_front();
if (!visited.insert(node).second) continue;
segment_nodes.erase(node);
- for (auto in :
- is_input_nodes ? node->in_nodes() : node->out_nodes()) {
+ for (auto in : (is_input_nodes || node->type_string() == "Const")
+ ? node->in_nodes()
+ : node->out_nodes()) {
if (segment_nodes.count(in)) {
que->push_back(in);
- VLOG(2) << "Need to remove node " << in->name()
- << " because one of its "
- << (is_input_nodes ? "output" : "input")
- << " nodes in the graph was removed: " << node->name();
+ if (VLOG_IS_ON(2)) {
+ if (!logged.count(in)) {
+ VLOG(2) << "----> Need to remove node " << in->name()
+ << " because one of its "
+ << (is_input_nodes ? "output" : "input")
+ << " nodes in the graph was removed: "
+ << node->name();
+ logged.insert(in);
+ }
+ }
}
}
}
@@ -594,7 +603,7 @@ tensorflow::Status SegmentGraph(
for (const auto& itr : sg_map) {
const std::set<const tensorflow::Node*>& segment_nodes = itr.second;
if (VLOG_IS_ON(1)) {
- string s;
+ string s = "parent=" + itr.first + ":";
for (auto node : segment_nodes) s += " " + node->name();
VLOG(1) << "Segment " << segments->size() << ": " << s;
}
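
The updated comment above describes the pruning heuristic in words: input
nodes and const output nodes drop all of their inputs, non-const output nodes
drop all of their outputs, and the removal is propagated by a BFS. The sketch
below is only an illustrative Python rendering of that heuristic, not the
segment.cc code itself; it assumes node objects exposing in_nodes(),
out_nodes() and type_string(), analogous to the SimpleNode/Node wrappers used
above.

    from collections import deque

    def remove_nodes(segment_nodes, is_input_nodes, boundary_nodes):
        # BFS from the disallowed boundary nodes, mirroring the lambda above.
        queue = deque(boundary_nodes)
        visited = set()
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            segment_nodes.discard(node)
            # Input nodes and const output nodes spread removal to their
            # inputs; non-const output nodes spread it to their outputs.
            neighbors = (node.in_nodes()
                         if is_input_nodes or node.type_string() == "Const"
                         else node.out_nodes())
            for n in neighbors:
                if n in segment_nodes:
                    queue.append(n)
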
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index 432e7b1c04..5937fa8259 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -206,7 +206,7 @@ TEST_F(SegmentTest, Multiple) {
// Make add5 not a TRT candidate, and we expect two segments.
auto without_add5 = all_adds - "add5";
RunTest(&g, without_add5, without_add5, without_add5,
- {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}});
+ {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});
// Make add8 not a candidate and add6 not an input candidate, then all direct
// and indirect inputs of add6 will be removed from the segment.
@@ -252,7 +252,7 @@ TEST_F(SegmentTest, BigIfElse) {
const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
"add4", "add5", "add6", "add7"};
RunTest(&g, all_adds - "add2", all_adds, all_adds,
- {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}});
+ {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
}
} // namespace test
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index edd30ad7a9..8ea5a63735 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -20,17 +20,19 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
-class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Create a graph containing single segment."""
@@ -65,13 +67,17 @@ class SimpleSingleEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+      # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph, which
+      # breaks the connection check; fix it.
+ # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
+ # "relu", "identity", "max_pool"]
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(100, 6, 6, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
-class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
+class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
def GetParams(self):
"""Create a graph containing multiple segment."""
@@ -95,32 +101,246 @@ class SimpleMultiEngineGraphDefTest(trt_test.TfTrtIntegrationTestBase):
padding="SAME",
name="conv")
c1 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
- p = conv * c1
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1")
+ p = math_ops.mul(conv, c1, name="mul")
c2 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype)
- q = conv / c2
+ np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2")
+ q = math_ops.div(conv, c2, name="div")
- edge = self.trt_incompatible_op(q)
- edge /= edge
- r = edge + edge
+ edge = self.trt_incompatible_op(q, name="incompatible")
+ edge = math_ops.div(edge, edge, name="div1")
+ r = math_ops.add(edge, edge, name="add")
- p -= edge
- q *= edge
- s = p + q
- s -= r
+ p = math_ops.sub(p, edge, name="sub")
+ q = math_ops.mul(q, edge, name="mul1")
+ s = math_ops.add(p, q, name="add1")
+ s = math_ops.sub(s, r, name="sub1")
array_ops.squeeze(s, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+      # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph, which
+      # breaks the connection check; fix it.
+ # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
+ # "add", "sub1"];
+ # - my_trt_op_1 should have ["weights","conv", "div"]
+ expected_engines=["my_trt_op_0", "my_trt_op_1"],
expected_output_dims=(100, 12, 12, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
-# TODO(aaroey): add a large complex graph to test.
+class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
+
+ def setUp(self):
+ """Setup method."""
+ super(PartiallyConvertedTestA, self).setUp()
+ # Let it fail to build the second engine.
+ trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail")
+
+ def GetParams(self):
+ """Create a graph containing two segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ for i in range(2):
+ c = constant_op.constant(1.0, name="c%d" % i)
+ n = math_ops.add(n, c, name="add%d" % i)
+ n = math_ops.mul(n, n, name="mul%d" % i)
+ edge = self.trt_incompatible_op(n, name="incompatible")
+ with g.control_dependencies([edge]):
+ c = constant_op.constant(1.0, name="c2")
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul2")
+ c = constant_op.constant(1.0, name="c3")
+ n = math_ops.add(n, c, name="add3")
+ n = math_ops.mul(n, n, name="mul3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ # Only the first engine is built.
+ "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class PartiallyConvertedTestB(PartiallyConvertedTestA):
+
+ def setUp(self):
+ """Setup method."""
+ super(PartiallyConvertedTestB, self).setUp()
+ # Let it fail to build the first engine.
+ trt_convert.clear_test_values("")
+ trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
+
+ def GetParams(self):
+ """Create a graph containing two segment."""
+ return super(PartiallyConvertedTestB, self).GetParams()._replace(
+ expected_engines={
+ # Only the second engine is built.
+ "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+ })
+
+
+class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+      # Adds a control dependency from the constant op to a trt incompatible
+      # op, and a control dependency from the trt incompatible op to all other
+ # ops, to make sure the constant op cannot be contracted with any trt
+ # segment that depends on it.
+ with g.control_dependencies([c]):
+ d = self.trt_incompatible_op(n, name="incompatible")
+ with g.control_dependencies([d]):
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ n = self.trt_incompatible_op(n, name="incompatible1")
+ with g.control_dependencies([d]):
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul1")
+ n = math_ops.add(n, n, name="add3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["add", "add1", "mul"],
+ "my_trt_op_1": ["add2", "add3", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing single segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]},
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ n = self.trt_incompatible_op(n, name="incompatible1")
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul1")
+ n = math_ops.add(n, n, name="add3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["add2", "add3", "mul1"],
+ # Why segment ["add", "add1", "mul"] was assigned segment id 1
+          # instead of 0: the parent node of this segment is actually the const
+          # node 'c', but it's removed later since it's a const output of the
+          # segment, which is not allowed.
+ "my_trt_op_1": ["add", "add1", "mul"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ c1 = constant_op.constant(1.0, name="c1")
+ c2 = constant_op.constant(1.0, name="c2")
+ d1 = constant_op.constant(1.0, name="d1")
+ d2 = self.trt_incompatible_op(inp, name="d2")
+ with g.control_dependencies([d1, d2]):
+ add = math_ops.add(inp, c1, name="add")
+ with g.control_dependencies([d1, d2]):
+ mul = math_ops.mul(add, add, name="mul")
+ with g.control_dependencies([d1, d2]):
+ add1 = math_ops.add(mul, mul, name="add1")
+ edge = self.trt_incompatible_op(add1, name="incompatible")
+ with g.control_dependencies([d1, d2, add, mul]):
+ add2 = math_ops.add(edge, c2, name="add2")
+ with g.control_dependencies([d1, d2, add1, mul]):
+ mul1 = math_ops.mul(add2, add2, name="mul1")
+ with g.control_dependencies([d1, d2, add, add1]):
+ add3 = math_ops.add(mul1, mul1, name="add3")
+ array_ops.squeeze(add3, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["c1", "add", "add1", "mul"],
+ "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
if __name__ == "__main__":
test.main()
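
The new PartiallyConvertedTestA/B cases above rely on the failure-injection
hooks that this change exports through trt_convert (enable_test_value,
clear_test_values, add_test_value and get_test_value). A minimal sketch of
that usage, with label strings following the "<engine name>:<method>"
convention used by the AddTestValue calls elsewhere in this change:

    from tensorflow.contrib.tensorrt.python import trt_convert

    trt_convert.enable_test_value()    # once per process, before conversion
    trt_convert.clear_test_values("")  # an empty pattern clears everything
    # Force the build of the second engine to fail.
    trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail")
    # ... convert and run the graph under test ...
    # Afterwards, confirm that the first engine actually executed.
    assert trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine") == "done"
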
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
index 730b6843fb..2e1107e303 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
@@ -66,7 +66,7 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name, w1_name, w2_name],
input_dims=[input_dims, w1_dims, w2_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(12, 5, 8, 7),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 0c03a10b64..8be32f59b4 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -102,7 +102,10 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=7,
+ expected_engines=[
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
+ ],
expected_output_dims=(48, 89),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
index dd673463a5..9316b14da0 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -109,7 +109,24 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=16,
+ expected_engines=[
+ "my_trt_op_0",
+ "my_trt_op_1",
+ "my_trt_op_2",
+ "my_trt_op_3",
+ "my_trt_op_4",
+ "my_trt_op_5",
+ "my_trt_op_6",
+ "my_trt_op_7",
+ "my_trt_op_8",
+ "my_trt_op_9",
+ "my_trt_op_10",
+ "my_trt_op_11",
+ "my_trt_op_12",
+ "my_trt_op_13",
+ "my_trt_op_14",
+ "my_trt_op_15",
+ ],
expected_output_dims=(5, 23040),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 8c51c45b0a..1874b9dd45 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py
@@ -73,7 +73,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(2, 126),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
index 97b29bf05d..8c59000b70 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
@@ -58,7 +58,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=['my_trt_op_0'],
expected_output_dims=(5, 12, 12, 1),
allclose_atol=1.e-02,
allclose_rtol=1.e-02)
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
index 3dd95c6f62..66eb6be757 100644
--- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -62,7 +62,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(2, 15, 15, 10),
allclose_atol=1.e-02,
allclose_rtol=1.e-02)
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
index 734ccf6345..fd55b8cd99 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -77,7 +77,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ expected_engines=["my_trt_op_0", "my_trt_op_1"],
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 50265c0845..51c905a50b 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
@@ -51,15 +51,18 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
name="conv")
b = constant_op.constant(
np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype)
- t = conv * b
- e = gen_math_ops.tan(conv)
- t = t - e
+ t = math_ops.mul(conv, b, name="mul")
+ e = self.trt_incompatible_op(conv, name="incompatible")
+ t = math_ops.sub(t, e, name="sub")
array_ops.squeeze(t, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=2,
+ expected_engines={
+ "my_trt_op_0": ["bias", "mul", "sub"],
+ "my_trt_op_1": ["weights", "conv"]
+ },
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
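
Here and in the other test updates, num_expected_engines is replaced by
expected_engines, which accepts either a plain list of engine op names or, as
in NeighboringEngineTest above, a dict mapping each engine to the original
node names it should absorb. Only the dict form triggers the additional
connection check added to the test base (see _VerifyConnections below). For
illustration, reusing the names from this test:

    # List form: only the engine count and the engine op names are verified.
    expected_engines = ["my_trt_op_0", "my_trt_op_1"]

    # Dict form: node membership and cross-engine connections are verified too.
    expected_engines = {
        "my_trt_op_0": ["bias", "mul", "sub"],
        "my_trt_op_1": ["weights", "conv"],
    }
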
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index bb7f5a77f0..6f85ada464 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from collections import namedtuple
import itertools
+import os
import warnings
import numpy as np
import six
@@ -30,6 +31,7 @@ from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -37,10 +39,14 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [
- "gdef", "input_names", "input_dims", "num_expected_engines",
+ "gdef", "input_names", "input_dims", "expected_engines",
"expected_output_dims", "allclose_atol", "allclose_rtol"
])
+RunParams = namedtuple(
+ "RunParams",
+ ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -48,6 +54,12 @@ def _IsQuantizationMode(mode):
return mode == "INT8"
+class GraphState(object):
+ ORIGINAL = 0
+ CALIBRATE = 1
+ INFERENCE = 2
+
+
class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""
@@ -63,45 +75,90 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def precision_modes(self):
return ["FP32", "FP16", "INT8"]
+ # str is bytes in py2, but unicode in py3.
+ def _ToUnicode(self, s):
+ if six.PY2:
+ if isinstance(s, unicode):
+ return s
+ return s.decode("utf-8")
+ else:
+ if isinstance(s, str):
+ return s
+ return s.decode("utf-8")
+
def _ToBytes(self, s):
if six.PY2:
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
return s
else:
- return s.encode("utf-8")
+ if isinstance(s, str):
+ return s.encode("utf-8")
+ return s
def _ToString(self, s):
if six.PY2:
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
return s
else:
+ if isinstance(s, str):
+ return s
return s.decode("utf-8")
+ @classmethod
+ def setUpClass(cls):
+ """Setup method for the module."""
+ super(TfTrtIntegrationTestBase, cls).setUpClass()
+ trt_convert.enable_test_value()
+
def setUp(self):
"""Setup method."""
super(TfTrtIntegrationTestBase, self).setUp()
warnings.simplefilter("always")
+ trt_convert.clear_test_values("")
def GetParams(self):
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
raise NotImplementedError()
- def _GetConfigProto(self,
- params,
- use_optimizer,
- precision_mode=None,
- is_dynamic_op=None):
+ def _PrepareRun(self, params, graph_state):
+ """Set up necessary testing environment before calling sess.run()."""
+ # Clear test values added by TRTEngineOp.
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
+ trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")
+
+ def _VerifyRun(self, params, graph_state):
+ """Verify the state after sess.run()."""
+ for engine_name in params.expected_engines:
+ if graph_state == GraphState.ORIGINAL:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.CALIBRATE:
+ self._ExpectCalibration(engine_name, "done")
+ self._ExpectNativeSegment(engine_name, "done")
+ self._ExpectTrtEngine(engine_name, "")
+ elif graph_state == GraphState.INFERENCE:
+ self._ExpectCalibration(engine_name, "")
+ self._ExpectNativeSegment(engine_name, "")
+ self._ExpectTrtEngine(engine_name, "done")
+
+ def _GetConfigProto(self, params, run_params, graph_state):
"""Get config proto based on specific settings."""
- if use_optimizer:
+ if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
rewriter_cfg = rewriter_config_pb2.RewriterConfig()
rewriter_cfg.optimizers.extend(["constfold", "layout"])
custom_op = rewriter_cfg.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 3
+ custom_op.parameter_map["minimum_segment_size"].i = 2
custom_op.parameter_map["max_batch_size"].i = max(
[dims[0] for dims in params.input_dims])
- custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
custom_op.parameter_map["precision_mode"].s = self._ToBytes(
- precision_mode)
+ run_params.precision_mode)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
graph_options = config_pb2.GraphOptions()
@@ -115,7 +172,26 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
gpu_options=gpu_options, graph_options=graph_options)
return config
- def _RunGraph(self, params, gdef, input_data, config, num_runs=2):
+ def _ExpectTestValue(self, engine_name, method, expected_value):
+ label = "%s:%s" % (engine_name, method)
+ actual_value = trt_convert.get_test_value(label)
+ self.assertEqual(
+ expected_value,
+ actual_value,
+ msg="Unexpected test value with label %s. Actual: %s; expected: %s" %
+ (label, actual_value, expected_value))
+
+ def _ExpectCalibration(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteCalibration", value)
+
+ def _ExpectTrtEngine(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value)
+
+ def _ExpectNativeSegment(self, engine_name, value):
+ self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)
+
+ def _RunGraph(self, params, gdef, input_data, config, graph_state,
+ num_runs=2):
"""Run given graphdef multiple times."""
assert len(params.input_names) == len(input_data)
g = ops.Graph()
@@ -132,93 +208,170 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
val = None
# Defaults to 2 runs to verify result across multiple runs is same.
for _ in range(num_runs):
+ self._PrepareRun(params, graph_state)
new_val = sess.run(out,
{inp[i]: input_data[i] for i in range(len(inp))})
self.assertEqual(params.expected_output_dims, new_val.shape)
if val is not None:
self.assertAllEqual(val, new_val)
val = new_val
+ self._VerifyRun(params, graph_state)
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
def _RunCalibration(self, params, gdef, input_data, config):
"""Run calibration on given graph."""
- return self._RunGraph(params, gdef, input_data, config, 30)
+ return self._RunGraph(
+ params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5)
- def _GetTrtGraphDef(self, params, gdef, precision_mode, is_dynamic_op):
+ def _GetTrtGraphDef(self, params, run_params, gdef):
"""Return trt converted graphdef."""
return trt_convert.create_inference_graph(
input_graph_def=gdef,
outputs=[self.output_name],
max_batch_size=max([dims[0] for dims in params.input_dims]),
max_workspace_size_bytes=1 << 25,
- precision_mode=precision_mode,
+ precision_mode=run_params.precision_mode,
minimum_segment_size=2,
- is_dynamic_op=is_dynamic_op)
-
- def _VerifyGraphDef(self,
- params,
- gdef,
- precision_mode=None,
- is_calibrated=None,
- dynamic_engine=None):
+ is_dynamic_op=run_params.dynamic_engine)
+
+ def _WriteGraph(self, params, run_params, gdef, graph_state):
+ if graph_state == GraphState.ORIGINAL:
+ label = "Original"
+ elif graph_state == GraphState.CALIBRATE:
+ label = "CalibEngine"
+ elif graph_state == GraphState.INFERENCE:
+ label = "InferEngine"
+ graph_name = (
+ self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
+ ".pbtxt")
+ temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
+ logging.info("Writing graph to %s/%s", temp_dir, graph_name)
+ graph_io.write_graph(gdef, temp_dir, graph_name)
+
+ def _VerifyConnections(self, params, converted_gdef):
+ old_to_new_node_map = {
+ self._ToString(node.name): self._ToString(node.name)
+ for node in params.gdef.node
+ }
+ for engine_name, node_names in params.expected_engines.items():
+ for node_name in node_names:
+ old_to_new_node_map[node_name] = engine_name
+ name_to_node_map = {
+ self._ToString(node.name): node for node in params.gdef.node
+ }
+
+ def _InputName(inp):
+ inp = self._ToString(inp)
+ prefix = ""
+ if inp[0] == "^":
+ prefix = "^"
+ inp = inp[1:]
+ parts = inp.split(":")
+ if len(parts) > 1 and parts[-1].isdigit():
+ inp = inp[:-len(parts[-1]) - 1]
+ return (prefix, inp)
+
+ expected_input_map = {}
+ for node in params.gdef.node:
+ name_str = self._ToString(node.name)
+ target_node_name = old_to_new_node_map[name_str]
+ is_engine_op = (target_node_name != name_str)
+ if target_node_name not in expected_input_map:
+ expected_input_map[target_node_name] = set()
+ input_set = expected_input_map[target_node_name]
+ for inp in node.input:
+ (prefix, inp_name) = _InputName(inp)
+ # Add the input only if it's outside the segment (note that it could be
+ # in a different engine).
+ if (not is_engine_op or
+ old_to_new_node_map[inp_name] != target_node_name):
+ if is_engine_op and name_to_node_map[inp_name].op == "Const":
+            # Const data input nodes to the segment have been copied to the
+            # segment graphdef and the engine, and the dependency has been
+            # converted to a control dependency.
+ input_set.add("^" + old_to_new_node_map[inp_name])
+ else:
+ input_set.add(prefix + old_to_new_node_map[inp_name])
+
+ actual_input_map = {}
+ for node in converted_gdef.node:
+ name_str = self._ToString(node.name)
+ actual_input_map[name_str] = set()
+ input_set = actual_input_map[name_str]
+ for inp in node.input:
+ (prefix, node_name) = _InputName(inp)
+ input_set.add(prefix + node_name)
+
+ self.assertEqual(
+ expected_input_map,
+ actual_input_map,
+ msg="expected:\n%s\nvs actual:\n%s" % (sorted(
+ expected_input_map.items()), sorted(actual_input_map.items())))
+
+ def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
+ self._WriteGraph(params, run_params, gdef, graph_state)
+
num_engines = 0
- for n in gdef.node:
- # TODO(jie): we should have coverage for failed conversion (TF fallback).
- # where the conversion will fail and we shouldn't count this engine as the
- # converted engines.
- if n.op == "TRTEngineOp":
+ for node in gdef.node:
+ if node.op == "TRTEngineOp":
num_engines += 1
- self.assertNotEqual(self._ToBytes(""), n.attr["serialized_segment"].s)
- self.assertNotEqual(self._ToBytes(""), n.attr["segment_funcdef_name"].s)
+ self.assertTrue(node.name in params.expected_engines)
+ self.assertTrue(len(node.attr["serialized_segment"].s))
+ self.assertTrue(len(node.attr["segment_funcdef_name"].s))
self.assertEqual(
- self._ToBytes(precision_mode), n.attr["precision_mode"].s)
- self.assertEqual(not dynamic_engine, n.attr["static_engine"].b)
- if _IsQuantizationMode(precision_mode) and is_calibrated:
- self.assertNotEqual(self._ToBytes(""), n.attr["calibration_data"].s)
+ self._ToBytes(run_params.precision_mode),
+ node.attr["precision_mode"].s)
+
+ is_dynamic_engine = not node.attr["static_engine"].b
+ self.assertEqual(run_params.dynamic_engine, is_dynamic_engine)
+
+ has_calibration_data = len(node.attr["calibration_data"].s)
+ if (_IsQuantizationMode(run_params.precision_mode) and
+ graph_state == GraphState.INFERENCE):
+ self.assertTrue(has_calibration_data)
else:
- self.assertEqual(self._ToBytes(""), n.attr["calibration_data"].s)
- if precision_mode is None: # This means gdef is the original GraphDef.
+ self.assertFalse(has_calibration_data)
+ if graph_state == GraphState.ORIGINAL:
self.assertEqual(0, num_engines)
else:
- self.assertEqual(num_engines, params.num_expected_engines)
+ self.assertEqual(num_engines, len(params.expected_engines))
+ if isinstance(params.expected_engines, dict):
+ self._VerifyConnections(params, gdef)
+ # TODO(aaroey): consider verifying the corresponding TF function.
- def RunTest(self, params, use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine):
- assert precision_mode in PRECISION_MODES
+ def RunTest(self, params, run_params):
+ assert run_params.precision_mode in PRECISION_MODES
input_data = [np.random.random_sample(dims) for dims in params.input_dims]
input_gdef = params.gdef
- self._VerifyGraphDef(params, input_gdef)
+ self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL)
# Get reference result without running trt.
- config_no_trt = self._GetConfigProto(params, False)
+ config_no_trt = self._GetConfigProto(params, run_params,
+ GraphState.ORIGINAL)
logging.info("Running original graph w/o trt, config:\n%s",
str(config_no_trt))
- ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt)
+ ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt,
+ GraphState.ORIGINAL)
# Run calibration if necessary.
- if _IsQuantizationMode(precision_mode):
+ if _IsQuantizationMode(run_params.precision_mode):
- calib_config = self._GetConfigProto(params, use_optimizer, precision_mode,
- dynamic_calib_engine)
+ calib_config = self._GetConfigProto(params, run_params,
+ GraphState.CALIBRATE)
logging.info("Running calibration graph, config:\n%s", str(calib_config))
- if use_optimizer:
- self.assertTrue(False)
- # TODO(aaroey): uncomment this and get infer_gdef when this mode is
- # supported.
- # result = self._RunCalibration(params, input_gdef, input_data,
- # calib_config)
+ if run_params.use_optimizer:
+ result = self._RunCalibration(params, input_gdef, input_data,
+ calib_config)
else:
- calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode,
- dynamic_calib_engine)
- self._VerifyGraphDef(params, calib_gdef, precision_mode, False,
- dynamic_calib_engine)
+ calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef)
+ self._VerifyGraphDef(params, run_params, calib_gdef,
+ GraphState.CALIBRATE)
result = self._RunCalibration(params, calib_gdef, input_data,
calib_config)
- infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
- self._VerifyGraphDef(params, infer_gdef, precision_mode, True,
- dynamic_calib_engine)
+ infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
+ self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE)
self.assertAllClose(
ref_result,
@@ -229,18 +382,19 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
infer_gdef = input_gdef
# Run inference.
- infer_config = self._GetConfigProto(params, use_optimizer, precision_mode,
- dynamic_infer_engine)
+ infer_config = self._GetConfigProto(params, run_params,
+ GraphState.INFERENCE)
logging.info("Running final inference graph, config:\n%s",
str(infer_config))
- if use_optimizer:
- result = self._RunGraph(params, infer_gdef, input_data, infer_config)
+ if run_params.use_optimizer:
+ result = self._RunGraph(params, infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
else:
- trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode,
- dynamic_infer_engine)
- self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True,
- dynamic_infer_engine)
- result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config)
+ trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef)
+ self._VerifyGraphDef(params, run_params, trt_infer_gdef,
+ GraphState.INFERENCE)
+ result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config,
+ GraphState.INFERENCE)
self.assertAllClose(
ref_result,
@@ -263,66 +417,44 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _AddTests(test_class):
"""Adds test methods to TfTrtIntegrationTestBase."""
- def _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine):
+ def _GetTest(run_params):
"""Gets a single test method based on the parameters."""
def _Test(self):
params = self.GetParams()
logging.info(
- "Running test with parameters: use_optimizer=%s, precision_mode=%s, "
- "dynamic_infer_engine=%s, dynamic_calib_engine=%s", use_optimizer,
- precision_mode, dynamic_infer_engine, dynamic_calib_engine)
- self.RunTest(params, use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine)
+ "Running test %s with parameters: use_optimizer=%s, "
+ "precision_mode=%s, dynamic_engine=%s",
+ "testTfTrt_" + run_params.test_name, run_params.use_optimizer,
+ run_params.precision_mode, run_params.dynamic_engine)
+ self.RunTest(params, run_params)
return _Test
use_optimizer_options = [False, True]
- dynamic_infer_engine_options = [False, True]
- dynamic_calib_engine_options = [False, True]
- for (use_optimizer, precision_mode,
- dynamic_infer_engine, dynamic_calib_engine) in itertools.product(
- use_optimizer_options, PRECISION_MODES, dynamic_infer_engine_options,
- dynamic_calib_engine_options):
+ dynamic_engine_options = [False, True]
+ for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
+ use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
if _IsQuantizationMode(precision_mode):
- if not dynamic_calib_engine and dynamic_infer_engine:
- # TODO(aaroey): test this case, the conversion from static calibration
- # engine to dynamic inference engine should be a noop.
- continue
if use_optimizer:
# TODO(aaroey): if use_optimizer is True we need to get the inference
# graphdef using custom python wrapper class, which is not currently
# supported yet.
continue
- if not dynamic_calib_engine:
+ if not dynamic_engine:
# TODO(aaroey): construction of static calibration engine is not
# supported yet.
continue
- if dynamic_calib_engine and not dynamic_infer_engine:
- # TODO(aaroey): construction of static inference engine using dynamic
- # calibration engine is not supported yet.
- continue
- else: # In non int8 mode.
- if dynamic_calib_engine:
- # dynamic_calib_engine doesn't affect non-int8 modes, so just let
- # related tests run once on dynamic_calib_engine=False.
- continue
conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
- infer_engine_type = ("DynamicInferEngine"
- if dynamic_infer_engine else "StaticInferEngine")
- calib_engine_type = ""
- if precision_mode == "INT8":
- calib_engine_type = ("DynamicCalibEngine"
- if dynamic_calib_engine else "StaticCalibEngine")
- test_name = "%s_%s_%s%s" % (conversion, precision_mode, infer_engine_type,
- ("_" + calib_engine_type)
- if len(calib_engine_type) else "")
- setattr(
- test_class, "testTfTRT_" + test_name,
- _GetTest(use_optimizer, precision_mode, dynamic_infer_engine,
- dynamic_calib_engine))
+ engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine")
+ test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type)
+ run_params = RunParams(
+ use_optimizer=use_optimizer,
+ precision_mode=precision_mode,
+ dynamic_engine=dynamic_engine,
+ test_name=test_name)
+ setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
if trt_convert.is_tensorrt_enabled():
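
With the refactor above, a new integration test only has to override
GetParams; _AddTests then attaches one generated
testTfTrt_<Conversion>_<PrecisionMode>_<EngineType> method per RunParams
combination. A skeletal example, modeled on ConstDataInputSingleEngineTest
from base_test.py (the class, node and engine names are illustrative):

    from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import math_ops

    class MyTrtTest(trt_test.TfTrtIntegrationTestBase):
      """Hypothetical subclass; only GetParams is provided."""

      def GetParams(self):
        input_name = "input"
        input_dims = [2, 32, 32, 3]
        g = ops.Graph()
        with g.as_default():
          inp = array_ops.placeholder(
              dtype=dtypes.float32, shape=input_dims, name=input_name)
          with g.device("/GPU:0"):
            n = math_ops.add(inp, inp, name="add")
            n = math_ops.mul(n, n, name="mul")
            n = math_ops.add(n, n, name="add1")
          array_ops.squeeze(n, name=self.output_name)
        return trt_test.TfTrtIntegrationTestParams(
            gdef=g.as_graph_def(),
            input_names=[input_name],
            input_dims=[input_dims],
            expected_engines={"my_trt_op_0": ["add", "add1", "mul"]},
            expected_output_dims=tuple(input_dims),
            allclose_atol=1.e-06,
            allclose_rtol=1.e-06)
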
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py
index b9e977cf67..500057a36d 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/contrib/tensorrt/test/unary_test.py
@@ -100,7 +100,10 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name, input2_name],
input_dims=[input_dims, input2_dims],
- num_expected_engines=5,
+ expected_engines=[
+ "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
+ "my_trt_op_4"
+ ],
expected_output_dims=(12, 5, 8, 12),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc
new file mode 100644
index 0000000000..276308b3a0
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/test/utils.h"
+
+#include <unordered_map>
+#include <vector>
+
+#include "re2/re2.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// TODO(aaroey): make this class thread-safe.
+class TestValueManager {
+ public:
+ static TestValueManager* singleton() {
+ static TestValueManager* manager = new TestValueManager();
+ return manager;
+ }
+
+ void Enable() {
+ VLOG(1) << "Enabling test value";
+ enabled_ = true;
+ }
+
+ void Add(const string& label, const string& value) {
+ if (TF_PREDICT_FALSE(enabled_)) {
+ QCHECK_NE("", value);
+ VLOG(1) << "Adding test value: " << label << " -> " << value;
+ values_.insert({label, value});
+ }
+ }
+
+ string Get(const string& label) {
+ if (TF_PREDICT_FALSE(enabled_)) {
+ VLOG(1) << "Getting test value by " << label;
+ auto itr = values_.find(label);
+ if (itr == values_.end()) return "";
+ return itr->second;
+ }
+ return "";
+ }
+
+ void Clear(const string& pattern) {
+ if (TF_PREDICT_FALSE(enabled_)) {
+ VLOG(1) << "Clearing test values";
+ if (pattern.empty()) {
+ values_.clear();
+ return;
+ }
+ std::vector<string> keys_to_clear;
+ for (const auto& kv : values_) {
+ if (RE2::FullMatch(kv.first, pattern)) {
+ keys_to_clear.push_back(kv.first);
+ }
+ }
+ for (const string& key : keys_to_clear) {
+ values_.erase(key);
+ }
+ }
+ }
+
+ private:
+ TestValueManager() : enabled_(false) {}
+
+ bool enabled_;
+ std::unordered_map<string, string> values_;
+};
+
+void EnableTestValue() { TestValueManager::singleton()->Enable(); }
+
+void ClearTestValues(const string& pattern) {
+ TestValueManager::singleton()->Clear(pattern);
+}
+
+void AddTestValue(const string& label, const string& value) {
+ TestValueManager::singleton()->Add(label, value);
+}
+
+string GetTestValue(const string& label) {
+ return TestValueManager::singleton()->Get(label);
+}
+
+} // namespace test
+} // namespace tensorrt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h
new file mode 100644
index 0000000000..4bb4120206
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/utils.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace test {
+
+// Helper methods to inject values used by testing tools.
+void EnableTestValue();
+void ClearTestValues(const string& pattern);
+void AddTestValue(const string& label, const string& value);
+string GetTestValue(const string& label);
+
+#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \
+ do { \
+ if (::tensorflow::tensorrt::test::GetTestValue(label) == \
+ value_to_return) { \
+ return errors::Internal("Injected manually"); \
+ } \
+ } while (0)
+
+} // namespace test
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
index 2b134c3bce..ab4d224db4 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
@@ -72,7 +72,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(5, 6, 2, 2),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
index bec2f23eff..56bdf848ea 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py
@@ -63,7 +63,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(5, 2, 2, 6),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 422740fdf6..6ea15fb8ef 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -101,82 +101,22 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/test/utils.h"
%}
%ignoreall
%unignore tensorflow;
-%unignore trt_convert;
%unignore calib_convert;
%unignore get_linked_tensorrt_version;
%unignore get_loaded_tensorrt_version;
%unignore is_tensorrt_enabled;
+%unignore enable_test_value;
+%unignore clear_test_values;
+%unignore add_test_value;
+%unignore get_test_value;
%{
-std::pair<string, string> trt_convert(
- string graph_def_string, // The serialized GraphDef string.
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size_bytes,
- int precision_mode,
- int minimum_segment_size,
- bool is_dyn_op,
- int max_cached_engines,
- std::vector<int> cached_engine_batches
- // Unfortunately we can't use TF_Status here since it
- // is in c/c_api and brings in a lot of other libraries
- // which in turn declare ops. These ops are included
- // statically in our library and cause an abort when
- // module is loaded due to double registration
- // until Tensorflow properly exposes these headers
- // we have to work around this by returning a string
- // and converting it to exception on python side.
- //,TF_Status* out_status) {
-) {
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
- string out_status;
-
- tensorflow::GraphDef graph_def;
- if (!graph_def.ParseFromString(graph_def_string)) {
- out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
- return std::pair<string, string>{out_status, ""};
- }
-
- if (precision_mode < 0 || precision_mode > 2) {
- out_status = "InvalidArgument;Invalid precision_mode";
- return std::pair<string, string>{out_status, ""};
- }
- if (!output_names.size()) {
- out_status = "InvalidArgument;Size of the output_names vector is 0";
- return std::pair<string, string>{out_status, ""};
- }
- tensorflow::GraphDef out_graph;
- tensorflow::Status conversion_status =
- tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
- graph_def, output_names, max_batch_size, max_workspace_size_bytes,
- &out_graph, precision_mode, minimum_segment_size,
- is_dyn_op, max_cached_engines, cached_engine_batches);
- if (!conversion_status.ok()) {
- auto retCode = (int)conversion_status.code();
- char buff[2000];
- snprintf(buff, 2000, "%d;%s", retCode,
- conversion_status.error_message().c_str());
- out_status = buff;
- return std::pair<string, string>{out_status, ""};
- }
- string result;
- if (!out_graph.SerializeToString(&result)) {
- out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
- return std::pair<string, string>{out_status, ""};
- }
- out_status = "OK;All good!";
- return std::pair<string, string>{out_status, result};
-#else
- // Returns FAILED_PRECONDITION.
- return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
-#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
-}
-
std::pair<string, string> calib_convert(
string graph_def_string, bool is_dyn_op
// unfortunately we can't use TF_Status here since it
@@ -251,20 +191,44 @@ bool is_tensorrt_enabled() {
return tensorflow::tensorrt::IsGoogleTensorRTEnabled();
}
-%}
+void enable_test_value() {
+ tensorflow::tensorrt::test::EnableTestValue();
+}
+
+#if PY_MAJOR_VERSION < 3
+#define TRT_PY_TO_CPP_STRING PyString_AsString
+#define TRT_CPP_TO_PY_STRING PyString_FromString
+#else
+#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8
+#define TRT_CPP_TO_PY_STRING PyUnicode_FromString
+#endif
+
+void clear_test_values(PyObject* pattern) {
+ tensorflow::tensorrt::test::ClearTestValues(
+ string(TRT_PY_TO_CPP_STRING(pattern)));
+}
+
+void add_test_value(PyObject* label, PyObject* value) {
+ tensorflow::tensorrt::test::AddTestValue(
+ string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value)));
+}
-std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
+PyObject* get_test_value(PyObject* label) {
+ string value = tensorflow::tensorrt::test::GetTestValue(
+ string(TRT_PY_TO_CPP_STRING(label)));
+ return TRT_CPP_TO_PY_STRING(value.c_str());
+}
-std::pair<string, string> trt_convert(string graph_def_string,
- std::vector<string> output_names,
- size_t max_batch_size,
- size_t max_workspace_size_bytes,
- int precision_mode, int minimum_segment_size,
- bool is_dyn_op,
- int max_cached_engines,
- std::vector<int> cached_engine_batches);
+%}
+
+std::pair<string, string> calib_convert(
+ string graph_def_string, bool is_dyn_op);
version_struct get_linked_tensorrt_version();
version_struct get_loaded_tensorrt_version();
bool is_tensorrt_enabled();
+void enable_test_value();
+void clear_test_values(PyObject* pattern);
+void add_test_value(PyObject* label, PyObject* value);
+PyObject* get_test_value(PyObject* label);
%unignoreall
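
The interface above now exposes the TensorRT test-value helpers to Python. A minimal sketch of driving them from a test follows; the wrap_conversion module path is an assumption based on the existing calib_convert/is_tensorrt_enabled bindings, and the label and pattern strings are illustrative.

    # Hedged sketch; module path and strings are assumptions, not part of the patch.
    from tensorflow.contrib.tensorrt import wrap_conversion as trt_wrap

    if trt_wrap.is_tensorrt_enabled():
      trt_wrap.enable_test_value()               # start recording test values
      trt_wrap.clear_test_values("my_trt_op_0")  # drop stale values matching a pattern
      trt_wrap.add_test_value("my_trt_op_0:segment", "native")
      print(trt_wrap.get_test_value("my_trt_op_0:segment"))
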
diff --git a/tensorflow/contrib/timeseries/__init__.py b/tensorflow/contrib/timeseries/__init__.py
index 11db56b1b7..654a4db098 100644
--- a/tensorflow/contrib/timeseries/__init__.py
+++ b/tensorflow/contrib/timeseries/__init__.py
@@ -27,6 +27,9 @@
@@TrainEvalFeatures
@@FilteringResults
+
+@@TimeSeriesRegressor
+@@OneShotPredictionHead
"""
from __future__ import absolute_import
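
With TimeSeriesRegressor and OneShotPredictionHead now exported, an estimator using the one-shot head can be constructed roughly as below; this is a hedged sketch mirroring the test changes later in this patch, and the ARModel hyperparameters are illustrative only.

    from tensorflow.contrib.timeseries.python.timeseries import ar_model
    from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
    from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib

    # Hyperparameters are illustrative only.
    estimator = ts_estimators.TimeSeriesRegressor(
        model=ar_model.ARModel(
            periodicities=10, input_window_size=10, output_window_size=6,
            num_features=1),
        head_type=ts_head_lib.OneShotPredictionHead)
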
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index 7020989d68..0e96c1fbd4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -161,6 +161,7 @@ py_test(
srcs = [
"head_test.py",
],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = ["no_pip_gpu"], # b/63391119
deps = [
diff --git a/tensorflow/contrib/timeseries/python/timeseries/__init__.py b/tensorflow/contrib/timeseries/python/timeseries/__init__.py
index c683dad71d..8462138339 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/__init__.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/__init__.py
@@ -24,5 +24,6 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils
from tensorflow.contrib.timeseries.python.timeseries.ar_model import *
from tensorflow.contrib.timeseries.python.timeseries.estimators import *
from tensorflow.contrib.timeseries.python.timeseries.feature_keys import *
+from tensorflow.contrib.timeseries.python.timeseries.head import *
from tensorflow.contrib.timeseries.python.timeseries.input_pipeline import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 769183f40a..0ddc4b4144 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.training import training as train
from tensorflow.python.util import nest
@@ -79,12 +80,137 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
model_dir=model_dir,
config=config)
- # TODO(allenl): A parsing input receiver function, which takes a serialized
- # tf.Example containing all features (times, values, any exogenous features)
- # and serialized model state (possibly also as a tf.Example).
- def build_raw_serving_input_receiver_fn(self,
- default_batch_size=None,
- default_series_length=None):
+ def _model_start_state_placeholders(
+ self, batch_size_tensor, static_batch_size=None):
+ """Creates placeholders with zeroed start state for the current model."""
+ gathered_state = {}
+ # Models may not know the shape of their state without creating some
+ # variables/ops. Avoid polluting the default graph by making a new one. We
+ # use only static metadata from the returned Tensors.
+ with ops.Graph().as_default():
+ self._model.initialize_graph()
+ # Evaluate the initial state as same-dtype "zero" values. These zero
+ # constants aren't used, but are necessary for feeding to
+ # placeholder_with_default for the "cold start" case where state is not
+ # fed to the model.
+ def _zeros_like_constant(tensor):
+ return tensor_util.constant_value(array_ops.zeros_like(tensor))
+ start_state = nest.map_structure(
+ _zeros_like_constant, self._model.get_start_state())
+ for prefixed_state_name, state in ts_head_lib.state_to_dictionary(
+ start_state).items():
+ state_shape_with_batch = tensor_shape.TensorShape(
+ (static_batch_size,)).concatenate(state.shape)
+ default_state_broadcast = array_ops.tile(
+ state[None, ...],
+ multiples=array_ops.concat(
+ [batch_size_tensor[None],
+ array_ops.ones(len(state.shape), dtype=dtypes.int32)],
+ axis=0))
+ gathered_state[prefixed_state_name] = array_ops.placeholder_with_default(
+ input=default_state_broadcast,
+ name=prefixed_state_name,
+ shape=state_shape_with_batch)
+ return gathered_state
+
+ def build_one_shot_parsing_serving_input_receiver_fn(
+ self, filtering_length, prediction_length, default_batch_size=None,
+ values_input_dtype=None, truncate_values=False):
+ """Build an input_receiver_fn for export_savedmodel accepting tf.Examples.
+
+ Only compatible with `OneShotPredictionHead` (see `head`).
+
+ Args:
+ filtering_length: The number of time steps used as input to the model, for
+ which values are provided. If more than `filtering_length` values are
+ provided (via `truncate_values`), only the first `filtering_length`
+ values are used.
+ prediction_length: The number of time steps requested as predictions from
+ the model. Times and all exogenous features must be provided for these
+ steps.
+ default_batch_size: If specified, must be a scalar integer. Sets the batch
+ size in the static shape information of all feature Tensors, which means
+ only this batch size will be accepted by the exported model. If None
+ (default), static shape information for batch sizes is omitted.
+ values_input_dtype: An optional dtype specification for values in the
+ tf.Example protos (either float32 or int64, since these are the numeric
+ types supported by tf.Example). After parsing, values are cast to the
+ model's dtype (float32 or float64).
+ truncate_values: If True, expects `filtering_length + prediction_length`
+ values to be provided, but only uses the first `filtering_length`. If
+ False (default), exactly `filtering_length` values must be provided.
+
+ Returns:
+ An input_receiver_fn which may be passed to the Estimator's
+ export_savedmodel.
+
+ Expects features contained in a vector of serialized tf.Examples with
+ shape [batch size] (dtype `tf.string`), each tf.Example containing
+ features with the following shapes:
+ times: [filtering_length + prediction_length] integer
+ values: [filtering_length, num features] floating point. If
+ `truncate_values` is True, expects `filtering_length +
+ prediction_length` values but only uses the first `filtering_length`.
+ all exogenous features: [filtering_length + prediction_length, ...]
+ (various dtypes)
+ """
+ if values_input_dtype is None:
+ values_input_dtype = dtypes.float32
+ if truncate_values:
+ values_proto_length = filtering_length + prediction_length
+ else:
+ values_proto_length = filtering_length
+
+ def _serving_input_receiver_fn():
+ """A receiver function to be passed to export_savedmodel."""
+ times_column = feature_column.numeric_column(
+ key=feature_keys.TrainEvalFeatures.TIMES, dtype=dtypes.int64)
+ values_column = feature_column.numeric_column(
+ key=feature_keys.TrainEvalFeatures.VALUES, dtype=values_input_dtype,
+ shape=(self._model.num_features,))
+ parsed_features_no_sequence = (
+ feature_column.make_parse_example_spec(
+ list(self._model.exogenous_feature_columns)
+ + [times_column, values_column]))
+ parsed_features = {}
+ for key, feature_spec in parsed_features_no_sequence.items():
+ if isinstance(feature_spec, parsing_ops.FixedLenFeature):
+ if key == feature_keys.TrainEvalFeatures.VALUES:
+ parsed_features[key] = feature_spec._replace(
+ shape=((values_proto_length,)
+ + feature_spec.shape))
+ else:
+ parsed_features[key] = feature_spec._replace(
+ shape=((filtering_length + prediction_length,)
+ + feature_spec.shape))
+ elif feature_spec.dtype == dtypes.string:
+ parsed_features[key] = parsing_ops.FixedLenFeature(
+ shape=(filtering_length + prediction_length,),
+ dtype=dtypes.string)
+ else: # VarLenFeature
+ raise ValueError("VarLenFeatures not supported, got %s for key %s"
+ % (feature_spec, key))
+ tfexamples = array_ops.placeholder(
+ shape=[default_batch_size], dtype=dtypes.string, name="input")
+ features = parsing_ops.parse_example(
+ serialized=tfexamples,
+ features=parsed_features)
+ features[feature_keys.TrainEvalFeatures.TIMES] = array_ops.squeeze(
+ features[feature_keys.TrainEvalFeatures.TIMES], axis=-1)
+ features[feature_keys.TrainEvalFeatures.VALUES] = math_ops.cast(
+ features[feature_keys.TrainEvalFeatures.VALUES],
+ dtype=self._model.dtype)[:, :filtering_length]
+ features.update(
+ self._model_start_state_placeholders(
+ batch_size_tensor=array_ops.shape(
+ features[feature_keys.TrainEvalFeatures.TIMES])[0],
+ static_batch_size=default_batch_size))
+ return export_lib.ServingInputReceiver(
+ features, {"examples": tfexamples})
+ return _serving_input_receiver_fn
+
+ def build_raw_serving_input_receiver_fn(
+ self, default_batch_size=None, default_series_length=None):
"""Build an input_receiver_fn for export_savedmodel which accepts arrays.
Automatically creates placeholders for exogenous `FeatureColumn`s passed to
@@ -149,34 +275,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
+ batch_only_feature_shape[1:])
placeholders[feature_key] = array_ops.placeholder(
dtype=value_dtype, name=feature_key, shape=feature_shape)
- # Models may not know the shape of their state without creating some
- # variables/ops. Avoid polluting the default graph by making a new one. We
- # use only static metadata from the returned Tensors.
- with ops.Graph().as_default():
- self._model.initialize_graph()
- # Evaluate the initial state as same-dtype "zero" values. These zero
- # constants aren't used, but are necessary for feeding to
- # placeholder_with_default for the "cold start" case where state is not
- # fed to the model.
- def _zeros_like_constant(tensor):
- return tensor_util.constant_value(array_ops.zeros_like(tensor))
- start_state = nest.map_structure(
- _zeros_like_constant, self._model.get_start_state())
batch_size_tensor = array_ops.shape(time_placeholder)[0]
- for prefixed_state_name, state in ts_head_lib.state_to_dictionary(
- start_state).items():
- state_shape_with_batch = tensor_shape.TensorShape(
- (default_batch_size,)).concatenate(state.shape)
- default_state_broadcast = array_ops.tile(
- state[None, ...],
- multiples=array_ops.concat(
- [batch_size_tensor[None],
- array_ops.ones(len(state.shape), dtype=dtypes.int32)],
- axis=0))
- placeholders[prefixed_state_name] = array_ops.placeholder_with_default(
- input=default_state_broadcast,
- name=prefixed_state_name,
- shape=state_shape_with_batch)
+ placeholders.update(
+ self._model_start_state_placeholders(
+ batch_size_tensor, static_batch_size=default_batch_size))
return export_lib.ServingInputReceiver(placeholders, placeholders)
return _serving_input_receiver_fn
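
The docstring above spells out the expected tf.Example layout; exporting with the new parsing receiver then looks roughly like the following hedged sketch, assuming `estimator` is a trained TimeSeriesRegressor built with OneShotPredictionHead (the export directory and window lengths are illustrative).

    # Hedged sketch; directory and lengths are illustrative.
    input_receiver_fn = estimator.build_one_shot_parsing_serving_input_receiver_fn(
        filtering_length=20, prediction_length=15)
    export_location = estimator.export_savedmodel("/tmp/ts_export",
                                                  input_receiver_fn)
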
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index 8686a803e5..d2484d0ef5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -26,6 +26,7 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -180,7 +181,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
return math_ops.cast(value, self.model.dtype)
if name == feature_keys.PredictionFeatures.STATE_TUPLE:
return value # Correct dtypes are model-dependent
- return ops.convert_to_tensor(value)
+ return sparse_tensor.convert_to_tensor_or_sparse_tensor(value)
def _gather_state(self, features):
"""Returns `features` with state packed, indicates if packing was done."""
@@ -202,6 +203,29 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
flat_sequence=[tensor for _, _, tensor in numbered_state])
return features, True
+ def _check_predict_features(self, features):
+ """Raises errors if features are not suitable for prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ # Model-dependent shapes
+ feature_keys.PredictionFeatures.STATE_TUPLE
+ ]))
+
def create_estimator_spec(self, features, mode, labels=None):
"""Performs basic error checking and returns an EstimatorSpec."""
with ops.name_scope(self._name, "head"):
@@ -230,7 +254,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
mode == estimator_lib.ModeKeys.EVAL):
_check_train_eval_features(features, self.model)
elif mode == estimator_lib.ModeKeys.PREDICT:
- _check_predict_features(features)
+ self._check_predict_features(features)
else:
raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
@@ -267,6 +291,36 @@ class OneShotPredictionHead(TimeSeriesRegressionHead):
each time predictions are requested when using this head.
"""
+ def _check_predict_features(self, features):
+ """Raises errors if features are not suitable for one-shot prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.TrainEvalFeatures.VALUES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.TrainEvalFeatures.VALUES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ # Model-dependent shapes
+ feature_keys.PredictionFeatures.STATE_TUPLE,
+ # One shot prediction head relies on values being shorter than
+ # times. Even though we're predicting eventually, we need values for
+ # the filtering phase.
+ feature_keys.TrainEvalFeatures.VALUES,
+ ]))
+
def _serving_ops(self, features):
"""Add ops for serving to the graph."""
with variable_scope.variable_scope("model", use_resource=True):
@@ -333,29 +387,6 @@ def _check_feature_shapes_compatible_with(features,
times_shape=compatible_with_value.get_shape()))
-def _check_predict_features(features):
- """Raises errors if features are not suitable for prediction."""
- if feature_keys.PredictionFeatures.TIMES not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.TIMES))
- if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.STATE_TUPLE))
- times_feature = features[feature_keys.PredictionFeatures.TIMES]
- if not times_feature.get_shape().is_compatible_with([None, None]):
- raise ValueError(
- ("Expected shape (batch dimension, window size) for feature '{}' "
- "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
- times_feature.get_shape()))
- _check_feature_shapes_compatible_with(
- features=features,
- compatible_with_name=feature_keys.PredictionFeatures.TIMES,
- compatible_with_value=times_feature,
- ignore=set([
- feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes
- ]))
-
-
def _check_train_eval_features(features, model):
"""Raise errors if features are not suitable for training/evaluation."""
if feature_keys.TrainEvalFeatures.TIMES not in features:
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 78c2cec21c..857e7c5635 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import os
from absl.testing import parameterized
@@ -26,12 +27,14 @@ import six
from tensorflow.contrib.estimator.python.estimator import extenders
from tensorflow.contrib.timeseries.examples import lstm as lstm_example
+from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import state_management
+from tensorflow.core.example import example_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.estimator import estimator_lib
@@ -343,15 +346,33 @@ def _structural_ensemble_regressor(
model_dir=model_dir)
+def _ar_lstm_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.TimeSeriesRegressor(
+ model=ar_model.ARModel(
+ periodicities=10, input_window_size=10, output_window_size=6,
+ num_features=5,
+ exogenous_feature_columns=exogenous_feature_columns,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel,
+ num_units=10)),
+ head_type=head_type,
+ model_dir=model_dir)
+
+
class OneShotTests(parameterized.TestCase):
@parameterized.named_parameters(
+ {"testcase_name": "ar_lstm_regressor",
+ "estimator_factory": _ar_lstm_regressor},
{"testcase_name": "custom_time_series_regressor",
"estimator_factory": _custom_time_series_regressor},
{"testcase_name": "structural_ensemble_regressor",
"estimator_factory": _structural_ensemble_regressor})
def test_one_shot_prediction_head_export(self, estimator_factory):
- model_dir = os.path.join(test.get_temp_dir(), str(ops.uid()))
+ def _new_temp_dir():
+ return os.path.join(test.get_temp_dir(), str(ops.uid()))
+ model_dir = _new_temp_dir()
categorical_column = feature_column.categorical_column_with_hash_bucket(
key="categorical_exogenous_feature", hash_bucket_size=16)
exogenous_feature_columns = [
@@ -377,7 +398,7 @@ class OneShotTests(parameterized.TestCase):
num_threads=1, batch_size=16, window_size=16)
estimator.train(input_fn=train_input_fn, steps=5)
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
- export_location = estimator.export_savedmodel(test.get_temp_dir(),
+ export_location = estimator.export_savedmodel(_new_temp_dir(),
input_receiver_fn)
graph = ops.Graph()
with graph.as_default():
@@ -412,6 +433,41 @@ class OneShotTests(parameterized.TestCase):
in predict_signature.outputs.items()}
output = session.run(fetches, feed_dict=feeds)
self.assertEqual((2, 15, 5), output["mean"].shape)
+ # Build a parsing input function, then make a tf.Example for it to parse.
+ export_location = estimator.export_savedmodel(
+ _new_temp_dir(),
+ estimator.build_one_shot_parsing_serving_input_receiver_fn(
+ filtering_length=20, prediction_length=15))
+ graph = ops.Graph()
+ with graph.as_default():
+ with session_lib.Session() as session:
+ example = example_pb2.Example()
+ times = example.features.feature[feature_keys.TrainEvalFeatures.TIMES]
+ values = example.features.feature[feature_keys.TrainEvalFeatures.VALUES]
+ times.int64_list.value.extend(range(35))
+ for i in range(20):
+ values.float_list.value.extend(
+ [float(i) * 2. + feature_number
+ for feature_number in range(5)])
+ real_feature = example.features.feature["2d_exogenous_feature"]
+      categorical_feature = example.features.feature[
+ "categorical_exogenous_feature"]
+ for i in range(35):
+ real_feature.float_list.value.extend([1, 1])
+        categorical_feature.bytes_list.value.append(b"strkey")
+ # Serialize the tf.Example for feeding to the Session
+ examples = [example.SerializeToString()] * 2
+ signatures = loader.load(
+ session, [tag_constants.SERVING], export_location)
+ predict_signature = signatures.signature_def[
+ feature_keys.SavedModelLabels.PREDICT]
+ ((_, input_value),) = predict_signature.inputs.items()
+ feeds = {graph.as_graph_element(input_value.name): examples}
+ fetches = {output_key: graph.as_graph_element(output_value.name)
+ for output_key, output_value
+ in predict_signature.outputs.items()}
+ output = session.run(fetches, feed_dict=feeds)
+ self.assertEqual((2, 15, 5), output["mean"].shape)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index c5855106de..f5d852908a 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -47,7 +47,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":tpu_lib",
- ":tpu_py",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/contrib/training:training_py",
@@ -136,7 +135,7 @@ py_library(
tf_custom_op_py_library(
name = "tpu_py",
- srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
+ srcs = glob(["python/ops/*.py"]),
dso = [":python/ops/_tpu_ops.so"],
kernels = [
":all_ops",
@@ -155,9 +154,13 @@ tf_custom_op_py_library(
py_library(
name = "tpu",
- srcs = ["python/tpu/__init__.py"],
+ srcs = [
+ "__init__.py",
+ "python/tpu/__init__.py",
+ ],
srcs_version = "PY2AND3",
deps = [
+ ":keras_support", # split out to avoid cycle with tpu_strategy
":tpu_estimator",
":tpu_lib",
],
@@ -172,19 +175,13 @@ py_library(
visibility = [
"//cloud/vmm/testing/tests/tpu:__subpackages__",
"//learning/brain:__subpackages__",
- # TODO(b/111651964): Clean special visibility for keras_support.
- #
- # Note: If you are an end user, please do not add your project to this
- # visibility. This feature is experimental, and will be made public
- # when ready.
- "//third_party/cloud_tpu/models/keras:__subpackages__",
"//tensorflow:__subpackages__",
+ "//third_party/cloud_tpu/models/keras:__subpackages__",
],
deps = [
":tpu_lib",
- ":tpu_py",
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
- "//tensorflow/contrib/distribute/python:tpu_strategy",
+ "//tensorflow/contrib/distribute",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index cac346ae30..d0a37eb0ed 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -47,6 +47,9 @@
@@InputPipelineConfig
@@TPUConfig
@@bfloat16_scope
+
+@@TPUDistributionStrategy
+@@keras_to_tpu_model
"""
from __future__ import absolute_import
@@ -58,6 +61,8 @@ from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
+from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
+from tensorflow.contrib.tpu.python.tpu.keras_support import TPUDistributionStrategy
from tensorflow.contrib.tpu.python.tpu.topology import *
from tensorflow.contrib.tpu.python.tpu.tpu import *
from tensorflow.contrib.tpu.python.tpu.tpu_config import *
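
keras_to_tpu_model and TPUDistributionStrategy are now re-exported from the contrib TPU package; a hedged usage sketch follows. The Keras model, TPU address, and cluster-resolver argument are illustrative, and the exact constructor arguments accepted by the underlying TPUStrategy may differ by version.

    import tensorflow as tf
    from tensorflow.contrib import tpu
    from tensorflow.contrib.cluster_resolver import TPUClusterResolver

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation="relu", input_shape=(20,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="sgd", loss="mse")

    resolver = TPUClusterResolver(tpu="grpc://10.0.0.2:8470")  # address is illustrative
    strategy = tpu.TPUDistributionStrategy(resolver)           # wraps tpu_strategy.TPUStrategy
    tpu_model = tpu.keras_to_tpu_model(model, strategy=strategy)
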
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index f80f5652af..8e6e9aa0cd 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -84,8 +84,6 @@ ProfileRequest PopulateProfileRequest(int duration_ms,
request.add_tools("memory_viewer");
request.add_tools("overview_page");
*request.mutable_opts() = opts;
- std::cout << "Limiting the number of trace events to " << kMaxEvents
- << std::endl;
return request;
}
@@ -99,7 +97,6 @@ bool Profile(const string& service_addr, const string& logdir, int duration_ms,
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;
- // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available.
// TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
// `ValidateHostPortPair` checks for empty host string case.
channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
@@ -166,6 +163,85 @@ bool NewSession(const string& service_addr,
return new_session_response.empty_trace();
}
+// Starts tracing on one or more TPU hosts and saves the result in the given
+// logdir. If no trace was collected, retries tracing up to
+// num_tracing_attempts times.
+void StartTracing(const tensorflow::string& service_addr,
+ const tensorflow::string& logdir,
+ const tensorflow::string& workers_list,
+ bool include_dataset_ops, int duration_ms,
+ int num_tracing_attempts) {
+ // Use the current timestamp as the run name.
+ tensorflow::string session_id = GetCurrentTimeStampAsString();
+ constexpr char kProfilePluginDirectory[] = "plugins/profile/";
+ tensorflow::string repository_root =
+ io::JoinPath(logdir, kProfilePluginDirectory);
+ std::vector<tensorflow::string> hostnames =
+ tensorflow::str_util::Split(workers_list, ",");
+
+ bool empty_trace = false;
+ int remaining_attempts = num_tracing_attempts;
+ tensorflow::ProfileOptions opts;
+ opts.set_include_dataset_ops(include_dataset_ops);
+ while (true) {
+ std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
+ << "Remaining attempt(s): " << remaining_attempts-- << std::endl;
+ if (hostnames.empty()) {
+ empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms,
+ repository_root, session_id, opts);
+ } else {
+ tensorflow::string tpu_master = service_addr;
+ empty_trace =
+ tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms,
+ repository_root, session_id, opts);
+ }
+ if (remaining_attempts <= 0 || !empty_trace) break;
+ std::cout << "No trace event is collected. Automatically retrying."
+ << std::endl
+ << std::endl;
+ }
+
+ if (empty_trace) {
+ std::cout << "No trace event is collected after " << num_tracing_attempts
+ << " attempt(s). "
+ << "Perhaps, you want to try again (with more attempts?)."
+ << std::endl
+ << "Tip: increase number of attempts with --num_tracing_attempts."
+ << std::endl;
+ }
+}
+
+MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) {
+ MonitorRequest request;
+ request.set_duration_ms(duration_ms);
+ request.set_monitoring_level(monitoring_level);
+ return request;
+}
+
+// Repeatedly collects profiles and shows user-friendly metrics for
+// 'num_queries' time(s).
+void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
+ int monitoring_level, int num_queries) {
+ for (int query = 0; query < num_queries; ++query) {
+ MonitorRequest request =
+ PopulateMonitorRequest(duration_ms, monitoring_level);
+
+ ::grpc::ClientContext context;
+ ::grpc::ChannelArguments channel_args;
+ channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
+ std::numeric_limits<int32>::max());
+ std::unique_ptr<TPUProfiler::Stub> stub =
+ TPUProfiler::NewStub(::grpc::CreateCustomChannel(
+ "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
+ channel_args));
+ MonitorResponse response;
+ TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response)));
+
+ std::cout << "Xprof Monitoring Results (Sample " << query + 1 << "):\n\n"
+ << response.data() << std::flush;
+ }
+}
+
} // namespace
} // namespace tpu
} // namespace tensorflow
@@ -174,9 +250,11 @@ int main(int argc, char** argv) {
tensorflow::string FLAGS_service_addr;
tensorflow::string FLAGS_logdir;
tensorflow::string FLAGS_workers_list;
- int FLAGS_duration_ms = 2000;
+ int FLAGS_duration_ms = 0;
int FLAGS_num_tracing_attempts = 3;
bool FLAGS_include_dataset_ops = true;
+ int FLAGS_monitoring_level = 0;
+ int FLAGS_num_queries = 100;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr,
"Address of TPU profiler service e.g. localhost:8466"),
@@ -186,21 +264,38 @@ int main(int argc, char** argv) {
tensorflow::Flag("logdir", &FLAGS_logdir,
"Path of TensorBoard log directory e.g. /tmp/tb_log, "
"gs://tb_bucket"),
- tensorflow::Flag("duration_ms", &FLAGS_duration_ms,
- "Duration of tracing in ms. Default is 2000ms."),
+ tensorflow::Flag(
+ "duration_ms", &FLAGS_duration_ms,
+ "Duration of tracing or monitoring in ms. Default is 2000ms for "
+ "tracing and 1000ms for monitoring."),
tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts,
"Automatically retry N times when no trace event "
"is collected. Default is 3."),
tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops,
"Set to false to profile longer TPU device traces."),
- };
+ tensorflow::Flag("monitoring_level", &FLAGS_monitoring_level,
+ "Choose a monitoring level between 1 and 2 to monitor "
+ "your TPU job continuously. Level 2 is more verbose "
+ "than level 1 and shows more metrics."),
+ tensorflow::Flag("num_queries", &FLAGS_num_queries,
+ "This script will run monitoring for num_queries before "
+ "it stops.")};
std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION
<< std::endl;
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
- if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) {
+ if (!parse_ok || FLAGS_service_addr.empty() ||
+ (FLAGS_logdir.empty() && FLAGS_monitoring_level == 0)) {
+    // Fail if flags are not parsed correctly or service_addr is not provided.
+    // Also fail if neither logdir (required for tracing) nor a monitoring
+    // level (required for monitoring) is provided.
+ std::cout << usage.c_str() << std::endl;
+ return 2;
+ }
+ if (FLAGS_monitoring_level < 0 || FLAGS_monitoring_level > 2) {
+ // Invalid monitoring level.
std::cout << usage.c_str() << std::endl;
return 2;
}
@@ -213,52 +308,27 @@ int main(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
- // Sets the minimum duration_ms and tracing attempts to one.
- int duration_ms = std::max(FLAGS_duration_ms, 1);
- int remaining_attempts = std::max(FLAGS_num_tracing_attempts, 1);
- tensorflow::ProfileOptions opts;
- opts.set_include_dataset_ops(FLAGS_include_dataset_ops);
- tensorflow::ProfileResponse response;
-
- // Use the current timestamp as the run name.
- tensorflow::string session_id =
- tensorflow::tpu::GetCurrentTimeStampAsString();
- constexpr char kProfilePluginDirectory[] = "plugins/profile/";
- tensorflow::string repository_root =
- ::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory);
- std::vector<tensorflow::string> hostnames =
- tensorflow::str_util::Split(FLAGS_workers_list, ",");
-
- bool empty_trace = false;
- while (true) {
- std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
- << "Remaining attempt(s): " << remaining_attempts-- << std::endl;
- if (hostnames.empty()) {
- empty_trace = tensorflow::tpu::Profile(FLAGS_service_addr, FLAGS_logdir,
- duration_ms, repository_root,
- session_id, opts);
- } else {
- tensorflow::string tpu_master = FLAGS_service_addr;
- empty_trace =
- tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms,
- repository_root, session_id, opts);
- }
- if (remaining_attempts <= 0 || !empty_trace) break;
- std::cout << "No trace event is collected. Automatically retrying."
- << std::endl
- << std::endl;
+ // Sets the minimum duration_ms, tracing attempts and num queries.
+ int duration_ms = std::max(FLAGS_duration_ms, 0);
+ if (duration_ms == 0) {
+    // If the profiling duration was not set by the user or was set to a
+    // negative value, default to 2000ms for tracing and 1000ms for monitoring.
+ duration_ms = FLAGS_monitoring_level == 0 ? 2000 : 1000;
}
+ int num_tracing_attempts = std::max(FLAGS_num_tracing_attempts, 1);
+ int num_queries = std::max(FLAGS_num_queries, 1);
- if (empty_trace) {
- std::cout << "No trace event is collected after "
- << FLAGS_num_tracing_attempts << " attempt(s). "
- << "Perhaps, you want to try again (with more attempts?)."
- << std::endl
- << "Tip: increase number of attempts with --num_tracing_attempts."
+ if (FLAGS_monitoring_level != 0) {
+ std::cout << "Since monitoring level is provided, profile "
+ << FLAGS_service_addr << " for " << duration_ms
+ << "ms and show metrics for " << num_queries << " time(s)."
<< std::endl;
- // Don't dump profile data if no trace is collected.
- return 0;
+ tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms,
+ FLAGS_monitoring_level, num_queries);
+ } else {
+ tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir,
+ FLAGS_workers_list, FLAGS_include_dataset_ops,
+ duration_ms, num_tracing_attempts);
}
-
return 0;
}
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index 7a5d01cca4..438f442848 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -50,7 +50,8 @@ flags.DEFINE_string(
flags.DEFINE_string(
'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
'gs://tb_bucket')
-flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.')
+flags.DEFINE_integer('duration_ms', 0,
+ 'Duration of tracing or monitoring in ms.')
flags.DEFINE_integer(
'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
'event is collected.')
@@ -58,6 +59,14 @@ flags.DEFINE_boolean('include_dataset_ops', True,
'Set to false to profile longer TPU '
'device traces.')
+# Monitoring parameters
+flags.DEFINE_integer(
+ 'monitoring_level', 0, 'Choose a monitoring level between '
+ '1 and 2 to monitor your TPU job continuously.')
+flags.DEFINE_integer(
+ 'num_queries', 100,
+ 'This script will run monitoring for num_queries before it stops.')
+
FLAGS = flags.FLAGS
EXECUTABLE = 'data/capture_tpu_profile'
JOB_NAME = 'worker'
@@ -118,6 +127,8 @@ def main(unused_argv=None):
cmd.append('--duration_ms=' + str(FLAGS.duration_ms))
cmd.append('--num_tracing_attempts=' + str(FLAGS.num_tracing_attempts))
cmd.append('--include_dataset_ops=' + str(FLAGS.include_dataset_ops).lower())
+ cmd.append('--monitoring_level=' + str(FLAGS.monitoring_level))
+ cmd.append('--num_queries=' + str(FLAGS.num_queries))
subprocess.call(cmd)
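
With the monitoring flags plumbed through, the pip-packaged wrapper can be driven in monitoring mode roughly as follows. The flag names come from this patch; the console-script name and the service address are assumptions/illustrative.

    import subprocess

    subprocess.call([
        "capture_tpu_profile",           # pip console entry point (assumed)
        "--service_addr=10.0.0.2:8466",  # TPU profiler service address, illustrative
        "--monitoring_level=1",          # 1 or 2 selects monitoring instead of tracing
        "--duration_ms=1000",
        "--num_queries=10",
    ])
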
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 81798ee423..ff893a722f 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -55,7 +55,6 @@ import time
import numpy as np
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
-from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -82,7 +81,11 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
-TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name
+
+# Work-around dependency cycle between DistributionStrategy and TPU lib.
+def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name
+ from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
+ return tpu_strategy.TPUStrategy(*args, **kw)
class TPUEmbedding(embeddings.Embedding):
@@ -1130,7 +1133,7 @@ Output shape: %(output_shape)s
'layer': layer,
'input_shape': layer.input_shape,
'output_shape': layer.output_shape
- })
+ })
@experimental
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 92c1eaba71..7994c2c6c7 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -970,8 +970,15 @@ def rewrite(computation,
Args:
computation: A Python function that builds a computation to apply
to the input. If the function takes n inputs, 'inputs' should be
- a list of n tensors. If the function returns m outputs, rewrite
- will return a list of m tensors.
+ a list of n tensors.
+
+ `computation` may return a list of operations and tensors. Tensors must
+ come before operations in the returned list. The return value of
+      `rewrite` is a list of tensors corresponding to the tensors from
+      `computation`.
+
+ All `Operation`s returned from `computation` will be executed when
+ evaluating any of the returned output tensors.
inputs: A list of input tensors or `None` (equivalent to an empty list).
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to `computation`.
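
A minimal sketch of the contract documented above for `tpu.rewrite`: tensors come first in the returned list, trailing `Operation`s are run for their side effects, and `rewrite` returns only the rewritten tensors. The computation body below is purely illustrative.

    import tensorflow as tf
    from tensorflow.contrib import tpu

    def computation(x):
      y = x * 2.0
      # Purely illustrative side effect: tf.group returns an Operation, which
      # must come after all tensors in the returned list.
      side_effect = tf.group(tf.reduce_sum(y))
      return [y, side_effect]

    x = tf.placeholder(tf.float32, shape=[8])
    [y_out] = tpu.rewrite(computation, inputs=[x])  # only the tensor is returned
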
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index bd8f2c99a8..c104b2403c 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -2886,7 +2886,8 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
multi_tpu_predict_steps_on_single_shard,
inputs=[],
num_shards=num_cores,
- outputs_from_all_shards=False)
+ outputs_from_all_shards=False,
+ device_assignment=ctx.device_assignment)
scaffold = _get_scaffold(captured_scaffold_fn)
return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get()
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index f7fd66d33f..01bac891da 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
@@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir,
logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
stop_time = time.time() + timeout if timeout is not None else None
while True:
- checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
if checkpoint_path is None or checkpoint_path == last_checkpoint:
if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
return None
diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py
index 4877c010fa..94cf7788b2 100644
--- a/tensorflow/contrib/training/python/training/training_test.py
+++ b/tensorflow/contrib/training/python/training/training_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@@ -421,7 +422,7 @@ class TrainTest(test.TestCase):
train_op = self.create_train_op()
model_variables = variables_lib2.global_variables()
- model_path = saver_lib.latest_checkpoint(logdir1)
+ model_path = checkpoint_management.latest_checkpoint(logdir1)
assign_fn = variables_lib.assign_from_checkpoint_fn(
model_path, model_variables)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 326b040ad8..385a14eb44 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2238,6 +2238,7 @@ cc_library(
linkopts = ["-ldl"],
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
+ "//tensorflow/core/platform/default/build_config:logging",
],
)
@@ -2266,6 +2267,7 @@ cc_library(
linkopts = ["-ldl"],
deps = [
"//tensorflow/core/platform/default/build_config:gif",
+ "//tensorflow/core/platform/default/build_config:logging",
],
)
@@ -2292,6 +2294,7 @@ cc_library(
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [
+ "//tensorflow/core/platform/default/build_config:logging",
"@png_archive//:png",
],
)
diff --git a/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt b/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt
index ad1ada8d71..3134fceeca 100644
--- a/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Ceil.pbtxt
@@ -1,4 +1,4 @@
op {
graph_op_name: "Ceil"
- summary: "Returns element-wise smallest integer in not less than x."
+ summary: "Returns element-wise smallest integer not less than x."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt
index ea5669693e..dfd199d012 100644
--- a/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_IteratorGetNext.pbtxt
@@ -1,4 +1,4 @@
op {
graph_op_name: "IteratorGetNext"
- summary: "Gets the next output from the given iterator."
+ summary: "Gets the next output from the given iterator ."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorGetNextAsOptional.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorGetNextAsOptional.pbtxt
new file mode 100644
index 0000000000..7068336847
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IteratorGetNextAsOptional.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "IteratorGetNextAsOptional"
+ summary: "Gets the next output from the given iterator as an Optional variant."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_OptionalFromValue.pbtxt b/tensorflow/core/api_def/base_api/api_def_OptionalFromValue.pbtxt
new file mode 100644
index 0000000000..4a15eea424
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OptionalFromValue.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalFromValue"
+ summary: "Constructs an Optional variant from a tuple of tensors."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_OptionalGetValue.pbtxt b/tensorflow/core/api_def/base_api/api_def_OptionalGetValue.pbtxt
new file mode 100644
index 0000000000..11c0c545d0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OptionalGetValue.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalGetValue"
+ summary: "Returns the value stored in an Optional variant or raises an error if none exists."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_OptionalHasValue.pbtxt b/tensorflow/core/api_def/base_api/api_def_OptionalHasValue.pbtxt
new file mode 100644
index 0000000000..7669178427
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OptionalHasValue.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalHasValue"
+ summary: "Returns true if and only if the given Optional variant has a value."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_OptionalNone.pbtxt b/tensorflow/core/api_def/base_api/api_def_OptionalNone.pbtxt
new file mode 100644
index 0000000000..150062a704
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OptionalNone.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalNone"
+ summary: "Creates an Optional variant with no value."
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IteratorGetNextAsOptional.pbtxt b/tensorflow/core/api_def/python_api/api_def_IteratorGetNextAsOptional.pbtxt
new file mode 100644
index 0000000000..a88f422c21
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_IteratorGetNextAsOptional.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "IteratorGetNextAsOptional"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_OptionalFromValue.pbtxt b/tensorflow/core/api_def/python_api/api_def_OptionalFromValue.pbtxt
new file mode 100644
index 0000000000..c4949258e6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_OptionalFromValue.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalFromValue"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_OptionalGetValue.pbtxt b/tensorflow/core/api_def/python_api/api_def_OptionalGetValue.pbtxt
new file mode 100644
index 0000000000..e3d362ac6e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_OptionalGetValue.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalGetValue"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_OptionalHasValue.pbtxt b/tensorflow/core/api_def/python_api/api_def_OptionalHasValue.pbtxt
new file mode 100644
index 0000000000..7f5a96982a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_OptionalHasValue.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalHasValue"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_OptionalNone.pbtxt b/tensorflow/core/api_def/python_api/api_def_OptionalNone.pbtxt
new file mode 100644
index 0000000000..15d11c4169
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_OptionalNone.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "OptionalNone"
+ visibility: HIDDEN
+}
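
These ops back an Optional value type on the Python side. A hedged sketch follows, assuming the wrappers surface as tf.contrib.data.Optional as in releases from around this time; the exact location of the Python API is an assumption.

    import tensorflow as tf

    opt = tf.contrib.data.Optional.from_value(37)  # wraps OptionalFromValue
    with tf.Session() as sess:
      print(sess.run(opt.has_value()))  # True (OptionalHasValue)
      print(sess.run(opt.get_value()))  # 37   (OptionalGetValue)
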
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index 630b3702c8..f8cb854b52 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -340,4 +340,30 @@ Status CopyTensor::Register(DeviceType sender_device_type,
return Status::OK();
}
+namespace {
+
+// The following registrations enable a DT_VARIANT tensor element that contains
+// a wrapped `tensorflow::Tensor` to be copied between devices.
+static Status WrappedTensorDeviceCopy(
+ const Tensor& from, Tensor* to,
+ const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
+ if (DMAHelper::CanUseDMA(&from)) {
+ TF_RETURN_IF_ERROR(copy(from, to));
+ } else {
+ *to = from;
+ }
+
+ return Status::OK();
+}
+
+#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)
+
+REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
+REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
+REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
+
+} // namespace
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index b623ed4421..6ab2d1ebf1 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -47,6 +47,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
async_default_(async),
+ env_(opts.env),
use_send_tensor_rpc_(false) {
InitDeviceMapAndAsync();
if (opts.config.inter_op_parallelism_threads() > 0) {
@@ -58,34 +59,6 @@ EagerContext::EagerContext(const SessionOptions& opts,
}
}
-#ifndef __ANDROID__
-EagerContext::EagerContext(
- const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
- bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<ServerInterface> server,
- std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
- std::unique_ptr<DeviceMgr> remote_device_manager,
- const gtl::FlatMap<string, uint64>& remote_contexts)
- : policy_(default_policy),
- local_unowned_device_manager_(local_device_mgr),
- devices_(local_unowned_device_manager_->ListDevices()),
- rendezvous_(rendezvous),
- thread_pool_(NewThreadPoolFromSessionOptions(opts)),
- pflr_(new ProcessFunctionLibraryRuntime(
- local_unowned_device_manager_, opts.env, TF_GRAPH_DEF_VERSION,
- &func_lib_def_, {}, thread_pool_.get())),
- log_device_placement_(opts.config.log_device_placement()),
- async_default_(async),
- remote_device_manager_(std::move(remote_device_manager)),
- server_(std::move(server)),
- remote_eager_workers_(std::move(remote_eager_workers)),
- remote_contexts_(remote_contexts),
- use_send_tensor_rpc_(
- ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false)) {
- InitDeviceMapAndAsync();
-}
-#endif
-
void EagerContext::InitDeviceMapAndAsync() {
if (async_default_) {
executor_.EnableAsync();
@@ -148,15 +121,8 @@ ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
return policy_;
}
-EagerContext::~EagerContext() {
#ifndef __ANDROID__
- if (server_) {
- // TODO(nareshmodi): Fix this.
- LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
- "Servers don't support clean shutdown.";
- server_.release();
- }
-
+void EagerContext::CloseRemoteContexts() {
// Close all remote contexts.
std::vector<eager::CloseContextRequest> requests(remote_contexts_.size());
std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
@@ -183,6 +149,19 @@ EagerContext::~EagerContext() {
}
counter.Wait();
+}
+#endif
+
+EagerContext::~EagerContext() {
+#ifndef __ANDROID__
+ if (server_) {
+ // TODO(nareshmodi): Fix this.
+ LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
+ "Servers don't support clean shutdown.";
+ server_.release();
+ }
+
+ CloseRemoteContexts();
#endif
executor_.WaitForAllPendingNodes().IgnoreError();
@@ -318,6 +297,55 @@ Status EagerContext::GetClientAndContextID(Device* device,
return Status::OK();
}
+
+void EagerContext::InitializeRemote(
+ std::unique_ptr<ServerInterface> server,
+ std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+ std::unique_ptr<DeviceMgr> remote_device_manager,
+ const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
+ DeviceMgr* local_device_mgr) {
+ if (!remote_contexts_.empty()) {
+ CloseRemoteContexts();
+ }
+ remote_contexts_ = remote_contexts;
+
+ use_send_tensor_rpc_ =
+ ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);
+
+ local_unowned_device_manager_ = local_device_mgr;
+ local_device_manager_ = nullptr;
+ pflr_.reset(new ProcessFunctionLibraryRuntime(
+ local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
+ {}, thread_pool_.get()));
+
+ devices_ = local_unowned_device_manager_->ListDevices();
+ devices_map_.clear();
+
+ if (rendezvous_ != nullptr) rendezvous_->Unref();
+ rendezvous_ = r;
+
+ // Memory leak!
+ if (server_ != nullptr) {
+ LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
+ "Servers don't support clean shutdown.";
+ server_.release();
+ }
+
+ server_ = std::move(server);
+ remote_eager_workers_ = std::move(remote_eager_workers);
+
+ active_remote_contexts_.clear();
+ for (const auto& remote_context : remote_contexts_) {
+ active_remote_contexts_.insert(remote_context.second);
+ }
+
+ device_to_client_cache_.clear();
+ remote_device_manager_ = std::move(remote_device_manager);
+
+ InitDeviceMapAndAsync();
+
+ ClearCaches();
+}
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 9c8c599452..a0b612e6e5 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@@ -68,31 +69,6 @@ class EagerContext {
ContextDevicePlacementPolicy default_policy, bool async,
std::unique_ptr<DeviceMgr> device_mgr,
Rendezvous* rendezvous);
-
- // TODO(nareshmodi): Split this into 2 classes and hide functionality behind
- // an interface. Alternatively, encapsulate remote state into a separate
- // class/struct.
- //
- // Constructs an eager context that is able to communicate with remote
- // workers.
- //
- // Additional remote-specific args are:
- // - server: A ServerInterface that exports the tensorflow.WorkerService.
- // Note that this class expects the server to already have been started.
- // - remote_eager_workers: A cache from which we can get "EagerClient"s to
- // communicate with remote eager services.
- // - remote_device_mgr: A DeviceMgr* which contains all remote devices
- // (should contain no local devices).
- // - remote_contexts: A map containing task name to remote context ID.
-#ifndef __ANDROID__
- explicit EagerContext(
- const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
- bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<ServerInterface> server,
- std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
- std::unique_ptr<DeviceMgr> remote_device_manager,
- const gtl::FlatMap<string, uint64>& remote_contexts);
-#endif
~EagerContext();
// Returns the function library runtime for the given device.
@@ -183,7 +159,31 @@ class EagerContext {
Status GetClientAndContextID(Device* device, eager::EagerClient** client,
uint64* context_id);
+ // TODO(nareshmodi): Encapsulate remote state into a separate
+ // class/struct.
+ //
+ // Enables the eager context to communicate with remote devices.
+ //
+ // - server: A ServerInterface that exports the tensorflow.WorkerService.
+ // Note that this class expects the server to already have been started.
+ // - remote_eager_workers: A cache from which we can get "EagerClient"s to
+ // communicate with remote eager services.
+ // - remote_device_mgr: A DeviceMgr* which contains all remote devices
+ // (should contain no local devices).
+ // - remote_contexts: A map containing task name to remote context ID.
+ void InitializeRemote(
+ std::unique_ptr<ServerInterface> server,
+ std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+ std::unique_ptr<DeviceMgr> remote_device_manager,
+ const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
+ DeviceMgr* local_device_mgr);
+
+ bool HasActiveRemoteContext(uint64 context_id) {
+ return active_remote_contexts_.find(context_id) !=
+ active_remote_contexts_.end();
+ }
#endif
+
// If true, then tensors should be shipped across processes via the
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
// instead (which in turn use WorkerService.RecvTensor RPCs).
@@ -203,13 +203,13 @@ class EagerContext {
// Only one of the below is set.
std::unique_ptr<DeviceMgr> local_device_manager_;
- const DeviceMgr* local_unowned_device_manager_;
+ DeviceMgr* local_unowned_device_manager_;
// Devices owned by device_manager
std::vector<Device*> devices_;
// All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
- Rendezvous* const rendezvous_;
+ Rendezvous* rendezvous_;
mutex functions_mu_;
FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
@@ -220,7 +220,7 @@ class EagerContext {
// One FunctionLibraryRuntime per device.
// func_libs[i] is the FunctionLibraryRuntime corresponding to
// session->devices[i].
- const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
std::function<void(std::function<void()>)> runner_;
@@ -243,21 +243,25 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
- const std::unique_ptr<DeviceMgr> remote_device_manager_;
+ Env* const env_;
#ifndef __ANDROID__
+ void CloseRemoteContexts();
+ std::unique_ptr<DeviceMgr> remote_device_manager_;
+
// server_ is not marked const because we release (rather than destroy) it
// when the context is destroyed.
std::unique_ptr<ServerInterface> server_;
- const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
+ std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
- const gtl::FlatMap<string, uint64> remote_contexts_;
+ gtl::FlatMap<string, uint64> remote_contexts_;
+ gtl::FlatSet<uint64> active_remote_contexts_;
gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
device_to_client_cache_;
#endif
- const bool use_send_tensor_rpc_;
+ bool use_send_tensor_rpc_;
};
} // namespace tensorflow
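
HasActiveRemoteContext above is a plain set-membership test over the context IDs recorded by the most recent InitializeRemote call; remote tensor destructors use it to detect that their context has been swapped out. A small self-contained sketch of the same idiom, substituting std::unordered_set for gtl::FlatSet (RemoteContextRegistry is a hypothetical name, not a TensorFlow class):

#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <unordered_set>

class RemoteContextRegistry {
 public:
  // Corresponds to clearing and refilling active_remote_contexts_ during
  // InitializeRemote.
  void Reset(std::initializer_list<uint64_t> context_ids) {
    active_.clear();
    active_.insert(context_ids.begin(), context_ids.end());
  }

  // Same shape as EagerContext::HasActiveRemoteContext: a set membership test.
  bool HasActiveRemoteContext(uint64_t context_id) const {
    return active_.find(context_id) != active_.end();
  }

 private:
  std::unordered_set<uint64_t> active_;
};

int main() {
  RemoteContextRegistry registry;
  registry.Reset({11, 42});
  std::cout << registry.HasActiveRemoteContext(42) << "\n";  // prints 1
  std::cout << registry.HasActiveRemoteContext(7) << "\n";   // prints 0
  return 0;
}
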
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 333778ddf8..3837405e7f 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -129,7 +129,7 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
- auto pre_time = Env::Default()->NowMicros();
+ auto pre_time_nanos = Env::Default()->NowNanos();
TensorHandle* result_handle = nullptr;
Status status = EagerCopyToDevice(
*handle, ctx, expected_device->name().c_str(), &result_handle);
@@ -141,8 +141,13 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
auto* node_stats = dev_stats->add_node_stats();
node_stats->set_node_name("_Send");
- node_stats->set_all_start_micros(pre_time);
- node_stats->set_op_end_rel_micros(Env::Default()->NowMicros() - pre_time);
+ node_stats->set_all_start_micros(pre_time_nanos /
+ EnvTime::kMicrosToNanos);
+ node_stats->set_all_start_nanos(pre_time_nanos);
+ int64 now_nanos = Env::Default()->NowNanos();
+ node_stats->set_op_end_rel_micros((now_nanos - pre_time_nanos) /
+ EnvTime::kMicrosToNanos);
+ node_stats->set_op_end_rel_nanos(now_nanos - pre_time_nanos);
}
if (!status.ok()) {
if (result_handle != nullptr) result_handle->Unref();
@@ -206,222 +211,6 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
ndef.DebugString());
}
-#ifdef TENSORFLOW_EAGER_USE_XLA
-// Synthesizes and returns a wrapper function over `op`, which must be a
-// primitive op (e.g. matmul).
-//
-// The wrapper function conforms to the function signature expected by
-// XlaLaunch, with input params ordered by <constants, (variable) args and
-// resources>. For example, if the op has input params <Const1, Arg2, Const3,
-// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
-// Resource4> as the input params to the synthesized function.
-//
-// It populates `const_input_types`, `arg_input_types` and
-// `op_input_to_func_input` based on the reordering results, that the caller
-// can use them to build an XlaLaunch. On error, it returns NULL, and sets
-// `status` accordingly.
-const FunctionDef* OpToFunction(TFE_Op* op,
- std::vector<TF_DataType>* const_input_types,
- std::vector<TF_DataType>* arg_input_types,
- gtl::FlatMap<int, int>* op_input_to_func_input,
- TF_Status* status) {
- DCHECK(!op->operation.is_function());
-
- FunctionDef fdef;
-
- // Get the OpDef of the op we are trying to encapsulate.
- TFE_Context* ctx = op->operation.ctx;
- const OpRegistrationData* op_data;
- {
- status = ctx->context.FindFunctionOpData(op->operation.Name(), &op_data);
- if (!status.ok()) {
- return nullptr;
- }
- }
- const OpDef& op_def = op_data->op_def;
-
- OpDef* signature = fdef.mutable_signature();
-
- // Handle constant inputs.
- const std::unordered_set<string> const_inputs(
- *XlaOpRegistry::CompileTimeConstantInputs(op->operation.Name()));
-
- // First add place holders for the input args, so that we can refer to them
- // by position in the next loop. Also tally up the resource inputs.
- int num_resource_inputs = 0;
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- if (op_def.input_arg(i).type() == DT_RESOURCE) {
- ++num_resource_inputs;
- }
- signature->add_input_arg();
- }
-
- // Now we map the input params from `op_def` to `signature`, where the param
- // ordering for `signature` is: <constants, args, resources>.
- int const_index = 0;
- int arg_index = const_inputs.size();
- int resource_index = op_def.input_arg_size() - num_resource_inputs;
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- const OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
- OpDef::ArgDef* func_input_arg = nullptr;
- if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
- VLOG(1) << "For const input, mapping op input " << i << " to func input "
- << const_index;
- (*op_input_to_func_input)[i] = const_index;
- func_input_arg = signature->mutable_input_arg(const_index++);
- const_input_types->push_back(
- static_cast<TF_DataType>(op->operation.Inputs()[i]->dtype));
- } else if (op_input_arg.type() == DT_RESOURCE) {
- VLOG(1) << "For resource input, mapping op input " << i
- << " to func input " << resource_index;
- (*op_input_to_func_input)[i] = resource_index;
- func_input_arg = signature->mutable_input_arg(resource_index++);
- } else {
- VLOG(1) << "For arg input, mapping op input " << i << " to func input "
- << arg_index;
- (*op_input_to_func_input)[i] = arg_index;
- func_input_arg = signature->mutable_input_arg(arg_index++);
- arg_input_types->push_back(
- static_cast<TF_DataType>(op->operation.Inputs()[i]->dtype));
- }
-
- func_input_arg->set_name(op_input_arg.name());
- func_input_arg->set_type(op->operation.Inputs()[i]->dtype);
- }
- VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
-
- // Resources args are at the end of the function input params, and we should
- // have iterated over all of them.
- DCHECK_EQ(signature->input_arg_size(), resource_index);
-
- // Make the synthesized function's name unique.
- signature->set_name(
- strings::StrCat(op_def.name(), func_id_generator.fetch_add(1)));
-
- // Add the node def and set its input names to match op_def's names.
- const NodeDef& ndef = op->operation.MutableAttrs()->BuildNodeDef();
- DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
- *fdef.add_node_def() = ndef;
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
- }
- VLOG(1) << "Added NodeDef: " << fdef.DebugString();
-
- // Fix the output names and set output types.
- for (int i = 0; i < op_def.output_arg_size(); ++i) {
- OpDef::ArgDef* arg = signature->add_output_arg();
- const OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
- const string& out_tensor_name =
- strings::StrCat(ndef.name(), ":", op_def_arg.name(), ":", 0);
- arg->set_name(op_def_arg.name());
- (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
- const string& type_attr = op_def_arg.type_attr();
- if (!type_attr.empty()) {
- auto i = ndef.attr().find(type_attr);
- if (i == ndef.attr().end()) {
- status = errors::InvalidArgument(
- strings::StrCat("Could not find attr ", type_attr, " in NodeDef ",
- ndef.DebugString()));
- return nullptr;
- }
- arg->set_type(i->second.type());
- }
- }
- VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();
-
- status = ctx->context.AddFunctionDef(fdef);
- if (!status.ok()) return nullptr;
- const auto ret = ctx->context.FindFunctionDef(signature->name());
- DCHECK(ret != nullptr);
- return ret;
-}
-
-// Builds an XlaLaunch as a wrapper over 'op', so that 'op' can be executed
-// via XLA.
-std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
- VLOG(1) << "Creating XlaLaunch for TFE_Op " << op->operation.Name();
- auto launch_op = std::unique_ptr<TFE_Op>(
- TFE_NewOp(op->operation.ctx, "XlaLaunch", status));
- if (TF_GetCode(status) != TF_OK) return nullptr;
- if (op->operation.device) {
- TFE_OpSetDevice(launch_op.get(), op->operation.device->name().c_str(),
- status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- }
-
- const FunctionDef* fdef;
- { fdef = op->operation.ctx->FindFunctionDef(op->operation.Name()); }
- std::vector<TF_DataType> const_input_types;
- std::vector<TF_DataType> arg_input_types;
- gtl::FlatMap<int, int> op_input_to_func_input;
- if (fdef == nullptr) {
- // See if this is a primitive op, and if so create a function for it, so
- // that XlaLaunch can access it.
- fdef = OpToFunction(op, &const_input_types, &arg_input_types,
- &op_input_to_func_input, status);
- if (!status.ok()) return nullptr;
- } else {
- // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work
- // for functions, so we need to find another way to handle constant
- // inputs.
- for (int i = const_input_types.size();
- i < fdef->signature().input_arg_size(); ++i) {
- VLOG(1) << "Adding Targs from input arg " << i;
- const OpDef::ArgDef& arg = fdef->signature().input_arg(i);
- arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
- }
- }
- DCHECK(fdef != nullptr);
-
- // Copy inputs and their devices.
- // Since input param reordering may have occurred between `op` and
- // `launch_op` via `op_input_to_func_input`, adjust the actual inputs
- // accordingly.
- *launch_op->operation.MutableInputs() = op->operation.Inputs();
- for (TensorHandle* h : launch_op->operation.Inputs()) {
- h->Ref();
- }
- if (!op_input_to_func_input.empty()) {
- DCHECK_EQ(op->operation.Inputs().size(), op_input_to_func_input.size());
- for (int i = 0; i < op_input_to_func_input.size(); ++i) {
- VLOG(1) << "mapping op input " << i << " to func input "
- << op_input_to_func_input[i];
-
- (*launch_op->operation.MuableInputs())[op_input_to_func_input[i]] =
- op->operation.Inputs()[i];
- }
- }
- launch_op->operation.MutableAttrs()->NumInputs(op->operation.Inputs().size());
-
- TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
- const_input_types.size());
-
- // Set Targs and Nresources attrs.
- TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
- arg_input_types.size());
- const int num_resource_inputs = fdef->signature().input_arg_size() -
- const_input_types.size() -
- arg_input_types.size();
- TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);
-
- // Set Tresults attr.
- std::vector<TF_DataType> tresults;
- for (const OpDef::ArgDef& arg : fdef->signature().output_arg()) {
- tresults.push_back(static_cast<TF_DataType>(arg.type()));
- }
- TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
- tresults.size());
-
- // Set function attr.
- AttrValue attr_value;
- NameAttrList* func = attr_value.mutable_func();
- func->set_name(fdef->signature().name());
- launch_op->attrs.Set("function", attr_value);
-
- return launch_op;
-}
-#endif // TENSORFLOW_EAGER_USE_XLA
-
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
const auto& node_def = op->MutableAttrs()->BuildNodeDef();
const OpDef* op_def = nullptr;
@@ -462,14 +251,6 @@ Status EagerLocalExecute(EagerOperation* op,
EagerContext* ctx = op->EagerContext();
auto status = ctx->GetStatus();
if (!status.ok()) return status;
-#ifdef TENSORFLOW_EAGER_USE_XLA
- std::unique_ptr<TFE_Op> xla_launch_op;
- if (op->UseXla() && op->Name() != "XlaLaunch") {
- xla_launch_op = BuildXlaLaunch(op, status);
- if (!status.ok()) return status;
- op = xla_launch_op.get();
- }
-#endif // TENSORFLOW_EAGER_USE_XLA
// Ensure all resource-touching ops run in the device the resource is,
// regardless of anything else that has been specified. This is identical to
// the graph mode behavior.
@@ -522,8 +303,14 @@ Status EagerLocalExecute(EagerOperation* op,
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
tf_shared_lock l(*ctx->FunctionsMu());
- status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(),
- kernel);
+ auto* flr = ctx->func_lib(device);
+
+ if (flr == nullptr) {
+ return errors::Unavailable(
+ "Unable to find a FunctionLibraryRuntime corresponding to device ",
+ device->name());
+ }
+ status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel);
if (!status.ok()) {
delete kernel;
return status;
@@ -563,11 +350,15 @@ Status EagerLocalExecute(EagerOperation* op,
if (!status.ok()) return status;
std::unique_ptr<NodeExecStats> maybe_stats;
if (ctx->ShouldStoreMetadata()) {
+ int64 now_nanos = Env::Default()->NowNanos();
maybe_stats.reset(new NodeExecStats);
maybe_stats->set_node_name(op->Name());
- maybe_stats->set_all_start_micros(Env::Default()->NowMicros());
+ maybe_stats->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ maybe_stats->set_all_start_nanos(now_nanos);
maybe_stats->set_op_start_rel_micros(0);
- maybe_stats->set_scheduled_micros(Env::Default()->NowMicros());
+ maybe_stats->set_op_start_rel_nanos(0);
+ maybe_stats->set_scheduled_micros(now_nanos / EnvTime::kMicrosToNanos);
+ maybe_stats->set_scheduled_nanos(now_nanos);
// TODO(apassos) track referenced tensors
}
retvals->resize(*num_retvals);
@@ -598,6 +389,13 @@ std::function<void()> GetRemoteTensorDestructor(
EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
uint64 op_id, int output_num) {
return [ctx, eager_client, context_id, op_id, output_num]() {
+ if (!ctx->HasActiveRemoteContext(context_id)) {
+ // This means that this tensor was pointing to a remote device, which has
+ // been changed out from under us. Simply return since there is nothing we
+ // can do.
+ return tensorflow::Status::OK();
+ }
+
std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(context_id);
@@ -818,6 +616,11 @@ Status EagerExecute(EagerOperation* op,
return EagerLocalExecute(op, retvals, num_retvals);
}
+ if (op->EagerContext()->LogDevicePlacement()) {
+ LOG(INFO) << "Executing op " << op->Name() << " in device "
+ << op->Device()->name();
+ }
+
return EagerRemoteExecute(op, retvals->data(), num_retvals);
}
@@ -852,8 +655,10 @@ Status EagerExecute(EagerContext* ctx, Device* device,
// TODO(agarwal): change Run to take vector of handles ?
TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
if (maybe_stats != nullptr) {
- maybe_stats->set_op_end_rel_micros(Env::Default()->NowMicros() -
+ int64 nanos = Env::Default()->NowNanos();
+ maybe_stats->set_op_end_rel_micros(nanos / EnvTime::kMicrosToNanos -
maybe_stats->all_start_micros());
+ maybe_stats->set_op_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
mutex_lock ml(*ctx->MetadataMu());
if (ctx->ShouldStoreMetadata()) {
auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
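
The execute.cc changes above sample the clock once in nanoseconds and derive the microsecond stat fields from the same sample. A minimal sketch of that dual micros/nanos bookkeeping, with std::chrono standing in for Env::Default()->NowNanos() and kMicrosToNanos mirroring the EnvTime constant (1000 ns per us); nothing here is TensorFlow API:

#include <chrono>
#include <cstdint>
#include <iostream>

constexpr int64_t kMicrosToNanos = 1000;

int64_t NowNanos() {
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

int main() {
  const int64_t start_nanos = NowNanos();
  volatile double sink = 0;
  for (int i = 0; i < 1000000; ++i) sink += i;  // stand-in for the timed op
  const int64_t end_nanos = NowNanos();

  // Both resolutions come from the same nanosecond samples, so the
  // microsecond fields are simply the truncated nanosecond values.
  const int64_t rel_nanos = end_nanos - start_nanos;
  std::cout << "all_start_micros=" << start_nanos / kMicrosToNanos << "\n"
            << "op_end_rel_micros=" << rel_nanos / kMicrosToNanos << "\n"
            << "op_end_rel_nanos=" << rel_nanos << "\n";
  return 0;
}
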
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 8096139d90..c2fac4c2c8 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -127,36 +127,52 @@ bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
// Helper routines for collecting step stats.
namespace nodestats {
inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
+inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
-void SetScheduled(NodeExecStatsWrapper* stats, int64 t) {
+void SetScheduled(NodeExecStatsWrapper* stats, int64 nanos) {
if (!stats) return;
- stats->stats()->set_scheduled_micros(t);
+ stats->stats()->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats->stats()->set_scheduled_nanos(nanos);
}
void SetAllStart(NodeExecStatsWrapper* stats) {
if (!stats) return;
- stats->stats()->set_all_start_micros(NowInUsec());
+ int64 now_nanos = NowInNsec();
+ stats->stats()->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats->stats()->set_all_start_nanos(now_nanos);
}
void SetOpStart(NodeExecStatsWrapper* stats) {
if (!stats) return;
NodeExecStats* nt = stats->stats();
DCHECK_NE(nt->all_start_micros(), 0);
- nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
+ DCHECK_NE(nt->all_start_nanos(), 0);
+ int64 now_nanos = NowInNsec();
+ nt->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ nt->all_start_micros());
+ nt->set_op_start_rel_nanos(now_nanos - nt->all_start_nanos());
}
void SetOpEnd(NodeExecStatsWrapper* stats) {
if (!stats) return;
NodeExecStats* nt = stats->stats();
DCHECK_NE(nt->all_start_micros(), 0);
- nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
+ DCHECK_NE(nt->all_start_nanos(), 0);
+ int64 now_nanos = NowInNsec();
+ nt->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ nt->all_start_micros());
+ nt->set_op_end_rel_nanos(now_nanos - nt->all_start_nanos());
}
void SetAllEnd(NodeExecStatsWrapper* stats) {
if (!stats) return;
NodeExecStats* nt = stats->stats();
DCHECK_NE(nt->all_start_micros(), 0);
- nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
+ DCHECK_NE(nt->all_start_nanos(), 0);
+ int64 now_nanos = NowInNsec();
+ nt->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ nt->all_start_micros());
+ nt->set_all_end_rel_nanos(now_nanos - nt->all_start_nanos());
}
void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
@@ -1357,7 +1373,7 @@ class ExecutorState {
TaggedNodeSeq* ready);
// Process a ready node in current thread.
- void Process(TaggedNode node, int64 scheduled_usec);
+ void Process(TaggedNode node, int64 scheduled_nsec);
// Before invoking item->kernel, fills in its "inputs".
Status PrepareInputs(const NodeItem& item, Entry* first_input,
@@ -1615,7 +1631,7 @@ struct ExecutorState::AsyncState {
}
};
-void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
+void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
TaggedNodeReadyQueue inline_ready;
@@ -1680,7 +1696,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
params.track_allocations = true;
stats = new NodeExecStatsWrapper;
stats->stats()->set_node_name(node->name());
- nodestats::SetScheduled(stats, scheduled_usec);
+ nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
@@ -1823,7 +1839,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
}
if (stats) {
- scheduled_usec = nodestats::NowInUsec();
+ scheduled_nsec = nodestats::NowInNsec();
}
// Postprocess.
completed = NodeDone(s, item.node, ready, stats, &inline_ready);
@@ -2198,14 +2214,14 @@ void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
TaggedNodeReadyQueue* inline_ready) {
if (ready.empty()) return;
- int64 scheduled_usec = 0;
+ int64 scheduled_nsec = 0;
if (stats_collector_) {
- scheduled_usec = nodestats::NowInUsec();
+ scheduled_nsec = nodestats::NowInNsec();
}
if (inline_ready == nullptr) {
// Schedule to run all the ready ops in thread pool.
for (auto& tagged_node : ready) {
- runner_([=]() { Process(tagged_node, scheduled_usec); });
+ runner_([=]() { Process(tagged_node, scheduled_nsec); });
}
return;
}
@@ -2221,7 +2237,7 @@ void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
// Dispatch to another thread since there is plenty of work to
// do for this thread.
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
- scheduled_usec));
+ scheduled_nsec));
}
curr_expensive_node = &tagged_node;
}
@@ -2234,7 +2250,7 @@ void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
// There are inline nodes to run already. We dispatch this expensive
// node to other thread.
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
- scheduled_usec));
+ scheduled_nsec));
}
}
}
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 613470365d..d581f45a90 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -937,8 +937,8 @@ bool Placer::ClientHandlesErrorFormatting() const {
string Placer::RichNodeName(const Node* node) const {
string quoted_name = strings::StrCat("'", node->name(), "'");
if (ClientHandlesErrorFormatting()) {
- string file_and_line = error_format_tag(*node, "${file}:${line}");
- return strings::StrCat(quoted_name, " (defined at ", file_and_line, ")");
+ string file_and_line = error_format_tag(*node, "${defined_at}");
+ return strings::StrCat(quoted_name, file_and_line);
} else {
return quoted_name;
}
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index cede899842..87f2f2ceb9 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -1158,10 +1158,10 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
true);
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(
- str_util::StrContains(s.error_message(),
- "Cannot assign a device for operation 'in'"
- " (defined at ^^node:in:${file}:${line}^^)"));
+ LOG(WARNING) << s.error_message();
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "Cannot assign a device for operation 'in'"
+ "^^node:in:${defined_at}^^"));
}
// Test that the "Cannot assign a device" error message does not contain a
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index c1e514d5ad..e26761703b 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -206,6 +206,9 @@ void RingReducer::ContinueAfterInputCopy() {
group_size_tensor_ = group_size_val;
group_size_tensor_ready_.Notify();
}
+ } else {
+ // Value won't be used, so no need to initialize.
+ group_size_tensor_ready_.Notify();
}
Finish(RunAsyncParts());
}
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 2059b1ce0d..b2192c5a80 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -508,6 +508,7 @@ cc_library(
hdrs = ["collective_rma_distributed.h"],
deps = [
":cancellable_call",
+ ":request_id",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index b9a3502131..805e023b0f 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
@@ -47,6 +48,7 @@ class RecvBufCall : public CancellableCall {
req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
req_.set_src_device(peer_device);
req_.set_dst_device(to_device->name());
+ req_.set_request_id(GetUniqueRequestId());
}
~RecvBufCall() override {}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index 61f5369617..1b6d796bd4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -419,7 +419,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
} // namespace
GrpcWorker::GrpcWorker(WorkerEnv* worker_env)
- : Worker(worker_env), recv_tensor_recent_request_ids_(100000) {}
+ : Worker(worker_env), recent_request_ids_(100000) {}
// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for a response object, to avoid extra protocol buffer serialization
@@ -428,7 +428,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
- Status s = recv_tensor_recent_request_ids_.TrackUnique(
+ Status s = recent_request_ids_.TrackUnique(
request->request_id(), "RecvTensor (GrpcWorker)", *request);
if (!s.ok()) {
done(s);
@@ -508,6 +508,12 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) {
// This is a generic, low performance implementation appropriate for grpc.
+ Status s = recent_request_ids_.TrackUnique(request->request_id(),
+ "RecvBuf (GrpcWorker)", *request);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
CollectiveExecutor::Handle ce_handle(
env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
index c0ed0884bc..d9e48524de 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
@@ -49,7 +49,7 @@ class GrpcWorker : public Worker {
WorkerEnv* env();
private:
- RecentRequestIds recv_tensor_recent_request_ids_;
+ RecentRequestIds recent_request_ids_;
};
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env);
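
The worker-service changes above route both RecvTensor and RecvBuf through the same recent_request_ids_ tracker so that retried RPCs are rejected as duplicates. A toy illustration of that duplicate-suppression idea follows; the real RecentRequestIds class differs, and RecentIds here is a hypothetical sketch that merely remembers roughly the last N ids and rejects repeats, treating id 0 as "unset":

#include <cstddef>
#include <cstdint>
#include <deque>
#include <iostream>
#include <unordered_set>

class RecentIds {
 public:
  explicit RecentIds(std::size_t capacity) : capacity_(capacity) {}

  // Returns true the first time an id is seen, false on a repeat.
  bool TrackUnique(int64_t request_id) {
    if (request_id == 0) return true;  // Unset ids are never deduplicated.
    if (seen_.count(request_id) > 0) return false;
    seen_.insert(request_id);
    order_.push_back(request_id);
    if (order_.size() > capacity_) {
      seen_.erase(order_.front());
      order_.pop_front();
    }
    return true;
  }

 private:
  std::size_t capacity_;
  std::unordered_set<int64_t> seen_;
  std::deque<int64_t> order_;
};

int main() {
  RecentIds ids(100000);
  std::cout << ids.TrackUnique(12345) << "\n";  // 1: first occurrence accepted
  std::cout << ids.TrackUnique(12345) << "\n";  // 0: duplicate rejected
  return 0;
}
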
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index b53bd8d53d..b285accce7 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -826,19 +826,6 @@ Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
return Status::OK();
}
-Status OpKernelContext::release_output(StringPiece name, TensorValue* value) {
- int start, stop;
- TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
- if (stop != start + 1) {
- return errors::InvalidArgument("OpKernel used list-valued output name '",
- name,
- "' when single-valued output was "
- "expected");
- }
- *value = release_output(start);
- return Status::OK();
-}
-
bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
const auto& inputs = *params_->inputs;
for (size_t i = 1; i < inputs.size(); ++i) {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 2b7cc867da..aab95b785b 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -904,12 +904,6 @@ class OpKernelContext {
// Returns nullptr if allocate_output() or set_output() have not been called.
Status mutable_output(StringPiece name, Tensor** tensor);
- // Transfers ownership of an output tensor to the caller.
- // NOTE: For non-reference outputs, the caller takes responsibility
- // for deletion. For reference outputs, the caller does NOT take
- // responsibility for deletion.
- Status release_output(StringPiece name, TensorValue* value);
-
// Records device specific state about how the input tensors were
// computed.
//
diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto
index d98999cb54..67cc9e3845 100644
--- a/tensorflow/core/framework/step_stats.proto
+++ b/tensorflow/core/framework/step_stats.proto
@@ -67,6 +67,11 @@ message NodeExecStats {
uint32 thread_id = 10;
repeated AllocationDescription referenced_tensor = 11;
MemoryStats memory_stats = 12;
+ int64 all_start_nanos = 13;
+ int64 op_start_rel_nanos = 14;
+ int64 op_end_rel_nanos = 15;
+ int64 all_end_rel_nanos = 16;
+ int64 scheduled_nanos = 17;
};
message DeviceStepStats {
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 384a42fc11..5f805f6594 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -57,6 +57,10 @@ namespace tensorflow {
// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
+//
+// NOTE(mrry): The corresponding "copy function" registrations can be found in
+// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
+// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
namespace {
diff --git a/tensorflow/core/framework/tensor_testutil_test.cc b/tensorflow/core/framework/tensor_testutil_test.cc
index 8dca25ac4c..dd321535f2 100644
--- a/tensorflow/core/framework/tensor_testutil_test.cc
+++ b/tensorflow/core/framework/tensor_testutil_test.cc
@@ -249,10 +249,9 @@ TEST(TensorTestUtilTest, ExpectTensorCloseHalf) {
EXPECT_TRUE(Helper<T>::IsClose(HALF(3.141592f), HALF(3.141593f), HALF(0.0),
HALF(0.0)));
- // This case failed because HALF(1e7f) is stored as inf, which it shouldn't.
- // TODO(penporn): Debug Eigen::half and uncomment this test case.
- // EXPECT_FALSE(
- // Helper<T>::IsClose(HALF(1e7f), HALF(1e-7f), kDefaultTol, kDefaultTol));
+ // Trivial case.
+ EXPECT_FALSE(
+ Helper<T>::IsClose(HALF(1e4f), HALF(1e-4f), kDefaultTol, kDefaultTol));
#undef HALF
// Some of the cases failed because Eigen::half doesn't behave as expected.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 6a1b0aebfa..f31d22e105 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -653,39 +653,42 @@ NodeState& VirtualScheduler::GetNodeStateOrCreateIt(const NodeDef* node) {
CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
auto it = node_map_.find(node);
- if (it == node_map_.end()) {
- // Not found; create a NodeState for this node.
- it = node_map_.emplace(node, NodeState()).first;
- auto& node_state = it->second;
- node_state.input_properties =
- graph_properties_.GetInputProperties(node->name());
- node_state.output_properties =
- graph_properties_.GetOutputProperties(node->name());
-
- // Some ops may need further processing to the input / output properties:
- // _Send and _Recv.
- MaybeUpdateInputOutput(node);
-
- if (!IsSend(*node)) {
- node_state.device_name = DeviceName(node);
- // For _Send op, device_name will be set to Channel in CreateSendRecv().
- }
+ if (it != node_map_.end()) {
+ return it->second;
+ }
- // Initialize output port related data:
- // Assume the size of OutputProperties represents the number of output ports
- // of this node.
- for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
- node_state.time_no_references[i] = Costs::Duration::max();
- node_state.num_outputs_executed[i] = 0;
- // Populate an empty vector for each port. The caller will add nodes
- // that use this port as input.
- node_state.outputs[i] = {};
- }
- // Port_num -1 is for control dependency.
- node_state.time_no_references[-1] = Costs::Duration::max();
- node_state.num_outputs_executed[-1] = 0;
- node_state.outputs[-1] = {};
+ // Not found; create a NodeState for this node.
+ it = node_map_.emplace(node, NodeState()).first;
+ auto& node_state = it->second;
+ node_state.input_properties =
+ graph_properties_.GetInputProperties(node->name());
+ node_state.output_properties =
+ graph_properties_.GetOutputProperties(node->name());
+
+  // Some ops (_Send and _Recv) may need further processing of the input /
+  // output properties.
+ MaybeUpdateInputOutput(node);
+
+ if (!IsSend(*node)) {
+ node_state.device_name = DeviceName(node);
+ // For _Send op, device_name will be set to Channel in CreateSendRecv().
}
+
+ // Initialize output port related data:
+ // Assume the size of OutputProperties represents the number of output ports
+ // of this node.
+ for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
+ node_state.time_no_references[i] = Costs::Duration::max();
+ node_state.num_outputs_executed[i] = 0;
+ // Populate an empty vector for each port. The caller will add nodes
+ // that use this port as input.
+ node_state.outputs[i] = {};
+ }
+ // Port_num -1 is for control dependency.
+ node_state.time_no_references[-1] = Costs::Duration::max();
+ node_state.num_outputs_executed[-1] = 0;
+ node_state.outputs[-1] = {};
+
return it->second;
}
@@ -859,9 +862,10 @@ Costs VirtualScheduler::Summary() const {
const auto& memory_cost = op_cost_pair.second.memory_time.count();
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (cost) { // Skip printing out zero-cost ops.
- VLOG(1) << strings::Printf(" + %30s : %c %10ld / %10ld / %10ld",
- op.c_str(), (is_op_cost_accurate ? ' ' : '~'),
- cost, compute_cost, memory_cost);
+ VLOG(1) << strings::Printf(
+ " + %30s : %c %10lld / %10lld / %10lld", op.c_str(),
+ (is_op_cost_accurate ? ' ' : '~'), static_cast<int64>(cost),
+ static_cast<int64>(compute_cost), static_cast<int64>(memory_cost));
}
}
@@ -902,7 +906,7 @@ Costs VirtualScheduler::Summary() const {
<< ", at the end: "
<< strings::HumanReadableNumBytes(state.memory_usage);
- VLOG(1) << "Per-op execution time compute time / memory time "
+ VLOG(1) << "Per-op execution time / compute time / memory time "
"(and memory usage at peak memory usage):";
// Profile non-persistent op memory usage.
@@ -936,10 +940,12 @@ Costs VirtualScheduler::Summary() const {
: 0.0;
if (cost || mem_usage_percent > 1.0) {
// Print out only non-zero cost ops or ops with > 1% memory usage.
- VLOG(1) << strings::Printf(" + %30s : %c %10ld / %10ld / %10ld",
+ VLOG(1) << strings::Printf(" + %30s : %c %10lld / %10lld / %10lld",
op.c_str(),
- (is_op_cost_accurate ? ' ' : '~'), cost,
- compute_cost, memory_cost)
+ (is_op_cost_accurate ? ' ' : '~'),
+ static_cast<int64>(cost),
+ static_cast<int64>(compute_cost),
+ static_cast<int64>(memory_cost))
<< " (" << strings::HumanReadableNumBytes(op_mem_usage) << " ["
<< mem_usage_percent << "%] "
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
@@ -978,55 +984,59 @@ Costs VirtualScheduler::Summary() const {
}
Costs VirtualScheduler::Summary(RunMetadata* metadata) {
- if (metadata != nullptr) {
- StepStats* stepstats = metadata->mutable_step_stats();
- for (const auto& device : device_) {
- GraphDef* device_partition_graph = metadata->add_partition_graphs();
- DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
- device_stepstats->set_device(device.first);
- for (const auto& node_def : device.second.nodes_executed) {
- const NodeState& nodestate = node_map_.at(node_def);
- NodeExecStats* node_stats = device_stepstats->add_node_stats();
- uint64 total_output_size = 0;
- for (int slot = 0; slot < nodestate.output_properties.size(); slot++) {
- const auto& properties = nodestate.output_properties[slot];
- NodeOutput* no = node_stats->add_output();
- no->set_slot(slot);
- TensorDescription* tensor_descr = no->mutable_tensor_description();
- tensor_descr->set_dtype(properties.dtype());
- *tensor_descr->mutable_shape() = properties.shape();
- // Optional allocation description.
- const auto tensor_size =
- CalculateOutputSize(nodestate.output_properties, slot);
- total_output_size += tensor_size;
- tensor_descr->mutable_allocation_description()->set_requested_bytes(
- tensor_size);
- tensor_descr->mutable_allocation_description()->set_allocated_bytes(
- tensor_size);
- }
- node_stats->set_timeline_label(node_def->op());
- node_stats->set_node_name(node_def->name());
- node_stats->set_op_start_rel_micros(0);
- node_stats->set_all_start_micros(
- nodestate.time_scheduled.asMicroSeconds().count());
- node_stats->set_op_end_rel_micros(
- nodestate.time_finished.asMicroSeconds().count() -
- nodestate.time_scheduled.asMicroSeconds().count());
- node_stats->set_all_end_rel_micros(
- nodestate.time_finished.asMicroSeconds().count() -
- nodestate.time_scheduled.asMicroSeconds().count());
- auto* mem_stats = node_stats->mutable_memory_stats();
- // VirtualScheduler does not specify scratch pad memory usage.
- mem_stats->set_temp_memory_size(0);
- int64 persistent_memory_size = 0;
- if (IsPersistentNode(node_def)) {
- persistent_memory_size = total_output_size;
- }
- mem_stats->set_persistent_memory_size(persistent_memory_size);
- *device_partition_graph->add_node() = *node_def;
+ if (!metadata) {
+ return Summary();
+ }
+
+ // Fill RunMetadata.
+ StepStats* stepstats = metadata->mutable_step_stats();
+ for (const auto& device : device_) {
+ GraphDef* device_partition_graph = metadata->add_partition_graphs();
+ DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
+ device_stepstats->set_device(device.first);
+ for (const auto& node_def : device.second.nodes_executed) {
+ const NodeState& nodestate = node_map_.at(node_def);
+ NodeExecStats* node_stats = device_stepstats->add_node_stats();
+ uint64 total_output_size = 0;
+ for (int slot = 0; slot < nodestate.output_properties.size(); slot++) {
+ const auto& properties = nodestate.output_properties[slot];
+ NodeOutput* no = node_stats->add_output();
+ no->set_slot(slot);
+ TensorDescription* tensor_descr = no->mutable_tensor_description();
+ tensor_descr->set_dtype(properties.dtype());
+ *tensor_descr->mutable_shape() = properties.shape();
+ // Optional allocation description.
+ const auto tensor_size =
+ CalculateOutputSize(nodestate.output_properties, slot);
+ total_output_size += tensor_size;
+ tensor_descr->mutable_allocation_description()->set_requested_bytes(
+ tensor_size);
+ tensor_descr->mutable_allocation_description()->set_allocated_bytes(
+ tensor_size);
+ }
+ node_stats->set_timeline_label(node_def->op());
+ node_stats->set_node_name(node_def->name());
+ node_stats->set_op_start_rel_micros(0);
+ node_stats->set_all_start_micros(
+ nodestate.time_scheduled.asMicroSeconds().count());
+ node_stats->set_op_end_rel_micros(
+ nodestate.time_finished.asMicroSeconds().count() -
+ nodestate.time_scheduled.asMicroSeconds().count());
+ node_stats->set_all_end_rel_micros(
+ nodestate.time_finished.asMicroSeconds().count() -
+ nodestate.time_scheduled.asMicroSeconds().count());
+ auto* mem_stats = node_stats->mutable_memory_stats();
+ // VirtualScheduler does not specify scratch pad memory usage.
+ mem_stats->set_temp_memory_size(0);
+ int64 persistent_memory_size = 0;
+ if (IsPersistentNode(node_def)) {
+ persistent_memory_size = total_output_size;
}
+ mem_stats->set_persistent_memory_size(persistent_memory_size);
+ *device_partition_graph->add_node() = *node_def;
}
}
+
return Summary();
}
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 34d48819ac..353ca6f071 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -275,7 +275,6 @@ class VirtualScheduler {
// Return per device peak memory usage.
const std::unordered_map<string, int64> GetPeakMemoryUsage() const;
- protected:
const std::unordered_map<string, DeviceState>* GetDeviceStates() const {
return &device_;
}
@@ -283,6 +282,7 @@ class VirtualScheduler {
return &node_map_;
}
+ protected:
// Returns the size of output at port_num (unit: bytes). A special case is
// port_num -1, which is for control dependency and assumed to be 4 bytes.
int64 CalculateOutputSize(
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 7998f0a902..a6b6b6f8b2 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -22,9 +22,7 @@ namespace grappler {
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
- auto result = nodes_.emplace(node->name(), node);
- // Check that the graph doesn't contain multiple nodes with the same name.
- CHECK(result.second) << "Non unique node name detected: " << node->name();
+ AddUniqueNodeOrDie(node);
}
for (NodeDef& node : *graph_->mutable_node()) {
@@ -32,6 +30,12 @@ GraphView::GraphView(GraphDef* graph) : graph_(graph) {
}
}
+void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
+ auto result = nodes_.emplace(node->name(), node);
+ // Check that the graph doesn't contain multiple nodes with the same name.
+ CHECK(result.second) << "Non unique node name detected: " << node->name();
+}
+
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index 050789d2e2..ac260f85a0 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -115,6 +115,8 @@ class GraphView {
const NodeDef& node, bool include_controlling_edges) const;
protected:
+ // Add a new `node` to the graph.
+ void AddUniqueNodeOrDie(NodeDef* node);
// Add fanout to every `node` input.
void AddFanouts(NodeDef* node);
std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; }
diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc
index 6abafe11a2..f0aff90c6c 100644
--- a/tensorflow/core/grappler/mutable_graph_view.cc
+++ b/tensorflow/core/grappler/mutable_graph_view.cc
@@ -23,10 +23,22 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
auto* node_in_graph = GetGraph()->add_node();
*node_in_graph = std::move(node);
- auto result = MutableNodes()->emplace(node_in_graph->name(), node_in_graph);
- // Check that the graph doesn't contain multiple nodes with the same name.
- CHECK(result.second) << "Non unique node name detected: "
- << node_in_graph->name();
+ AddUniqueNodeOrDie(node_in_graph);
+
+ AddFanouts(node_in_graph);
+ return node_in_graph;
+}
+
+NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
+ const int output_port_id) {
+ auto* node_in_graph = GetGraph()->add_node();
+ *node_in_graph = std::move(node);
+
+ AddUniqueNodeOrDie(node_in_graph);
+
+  // Replace the input of the output nodes of `input_node` with `node`.
+ ReplaceInput(input_node, *node_in_graph, output_port_id);
+
AddFanouts(node_in_graph);
return node_in_graph;
}
diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h
index 105eb972e8..971e5503d4 100644
--- a/tensorflow/core/grappler/mutable_graph_view.h
+++ b/tensorflow/core/grappler/mutable_graph_view.h
@@ -29,9 +29,16 @@ class MutableGraphView : public GraphView {
using GraphView::GraphView;
GraphDef* GetGraph() { return MutableGraph(); }
+
// Adds a new node to graph and updates the view.
NodeDef* AddNode(NodeDef&& node);
+  // Inserts a new node into the graph after the `input` node and updates the
+  // view: `node` is added to the graph, and every consumer of `input`'s
+  // output port `output_port_id` is rewired to read from the new node.
+ NodeDef* InsertNode(const NodeDef& input, NodeDef&& node,
+ int output_port_id = 0);
+
// Replaces the input for the output nodes of 'old_input' with a port
// `output_port_id` with 'new_input'.
//
diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc
index f09dfb8271..2536bec35d 100644
--- a/tensorflow/core/grappler/mutable_graph_view_test.cc
+++ b/tensorflow/core/grappler/mutable_graph_view_test.cc
@@ -23,7 +23,18 @@ namespace tensorflow {
namespace grappler {
namespace {
-TEST(MutableGraphViewTest, AddAndReplaceInput) {
+bool FindChildWithName(const MutableGraphView& graph,
+ const string& output_port_name,
+ const string& input_name) {
+ GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0);
+ auto fanout = graph.GetFanout(output_port);
+ for (auto& input_port : fanout) {
+ if (input_port.node->name() == input_name) return true;
+ }
+ return false;
+}
+
+TrivialTestGraphInputYielder SimpleGraph() {
// This outputs simple graph like:
// x
// / \
@@ -35,7 +46,13 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
// AddN AddN_1
// \ /
// y
- TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"});
+ TrivialTestGraphInputYielder simple_graph(2, 2, 2, false,
+ {"/CPU:0", "/GPU:0"});
+ return simple_graph;
+}
+
+TEST(MutableGraphViewTest, AddAndReplaceInput) {
+ TrivialTestGraphInputYielder fake_input = SimpleGraph();
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
@@ -49,18 +66,7 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
EXPECT_EQ("Square", fanin.node->name());
EXPECT_EQ(0, fanin.port_id);
- auto find_child_with_name = [&graph](string output_port_name,
- string input_name) {
- GraphView::OutputPort output_port =
- graph.GetOutputPort(output_port_name, 0);
- auto fanout = graph.GetFanout(output_port);
- for (auto& input_port : fanout) {
- if (input_port.node->name() == input_name) return true;
- }
- return false;
- };
-
- EXPECT_FALSE(find_child_with_name("Square", "new_node"));
+ EXPECT_FALSE(FindChildWithName(graph, "Square", "new_node"));
NodeDef new_node = *input.node;
new_node.set_name("new_node");
@@ -70,13 +76,40 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
EXPECT_NE(graph.GetNode("new_node"), nullptr);
graph.ReplaceInput(*input.node, *node_in_graph);
- EXPECT_TRUE(find_child_with_name("Square", "new_node"));
- EXPECT_TRUE(find_child_with_name("new_node", "y"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square", "new_node"));
+ EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
+}
+
+TEST(MutableGraphViewTest, InsertNodes) {
+ TrivialTestGraphInputYielder fake_input = SimpleGraph();
+
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ GraphDef new_graph = item.graph;
+ MutableGraphView graph(&new_graph);
+
+ GraphView::InputPort input = graph.GetInputPort("AddN", 0);
+
+ NodeDef new_node = *input.node;
+ new_node.set_name("new_node");
+ new_node.set_input(0, input.node->name());
+
+ EXPECT_EQ(graph.GetNode("new_node"), nullptr);
+ graph.InsertNode(*input.node, std::move(new_node));
+ EXPECT_NE(graph.GetNode("new_node"), nullptr);
+ EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN_1"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN"));
+ EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN_1"));
+ EXPECT_TRUE(FindChildWithName(graph, "AddN", "new_node"));
+ EXPECT_TRUE(FindChildWithName(graph, "AddN_1", "y"));
+ EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
}
TEST(MutableGraphViewTest, DeleteNodes) {
// Outputs simple graph as described in first test.
- TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"});
+ TrivialTestGraphInputYielder fake_input = SimpleGraph();
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 85cb19d419..685b5379af 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -2489,6 +2489,11 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
auto tensors = EvaluateNodes(got, item.fetch);
EXPECT_EQ(7, tensors.size());
+ for (int i = 0; i < 7; ++i) {
+ EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
+ test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
+ }
+
GraphDef want;
AddNode("x", "Const", {}, {}, &want);
AddNode("y2", "Const", {}, {}, &want);
@@ -2534,6 +2539,11 @@ TEST_F(ArithmeticOptimizerTest, Log1p) {
auto tensors = EvaluateNodes(got, item.fetch);
EXPECT_EQ(2, tensors.size());
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
+ test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
+ }
+
GraphDef want;
AddNode("x1", "Const", {}, {}, &want);
AddNode("x2", "Const", {}, {}, &want);
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index d7ac58c99d..451ef6cabb 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -70,6 +70,26 @@ tf_cc_test(
)
cc_library(
+ name = "latency_all_edges",
+ srcs = ["latency_all_edges.cc"],
+ hdrs = [
+ "latency_all_edges.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "map_and_batch_fusion",
srcs = ["map_and_batch_fusion.cc"],
hdrs = [
@@ -213,6 +233,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":function_rename",
+ ":latency_all_edges",
":map_and_batch_fusion",
":map_fusion",
":noop_elimination",
@@ -220,3 +241,17 @@ cc_library(
],
alwayslink = 1,
)
+
+tf_cc_test(
+ name = "latency_all_edges_test",
+ srcs = ["latency_all_edges_test.cc"],
+ deps = [
+ ":graph_utils",
+ ":latency_all_edges",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 6ce6533369..838787d2a5 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -27,11 +27,17 @@ namespace {
constexpr char kConstOpName[] = "Const";
template <typename Predicate, typename Collection>
-int GetElementIdxWithPredicate(const Predicate& predicate,
- const Collection& collection) {
- auto it = std::find_if(collection.begin(), collection.end(), predicate);
- if (it == collection.end()) return -1;
- return std::distance(collection.begin(), it);
+std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ std::vector<int> indices = {};
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ indices.push_back(idx);
+ }
+ idx++;
+ }
+ return indices;
}
std::vector<int> CreateNameIndex(const GraphDef& graph) {
@@ -189,29 +195,39 @@ bool ContainsFunctionNodeWithName(const string& name,
}
int FindGraphNodeWithName(const string& name, const GraphDef& graph) {
- return GetElementIdxWithPredicate(
+ std::vector<int> indices = GetElementIndicesWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
graph.node());
+ return indices.empty() ? -1 : indices.front();
}
int FindNodeWithOp(const string& op, const GraphDef& graph) {
- return GetElementIdxWithPredicate(
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
+ return indices.empty() ? -1 : indices.front();
+}
+
+std::vector<int> FindAllGraphNodesWithOp(const string& op,
+ const GraphDef& graph) {
+ return GetElementIndicesWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
}
int FindGraphFunctionWithName(const string& name,
const FunctionDefLibrary& library) {
- return GetElementIdxWithPredicate(
+ std::vector<int> indices = GetElementIndicesWithPredicate(
[&name](const FunctionDef& function) {
return function.signature().name() == name;
},
library.function());
+ return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithName(const string& name, const FunctionDef& function) {
- return GetElementIdxWithPredicate(
+ std::vector<int> indices = GetElementIndicesWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
function.node_def());
+ return indices.empty() ? -1 : indices.front();
}
void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph,
@@ -219,7 +235,12 @@ void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph,
string name = prefix;
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
- name = strings::StrCat(prefix, "/_", id);
+ if (name.rfind("_generated") != std::string::npos &&
+ (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
+ name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
+ } else {
+ name = strings::StrCat(prefix, "/_", id);
+ }
++id;
}
node->set_name(std::move(name));
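
The SetUniqueGraphNodeName change above special-cases names ending in "_generated": the "/_<id>" disambiguator is inserted before that suffix rather than appended, so generated nodes keep their suffix. A standalone sketch of that rule under the assumption that `taken` stands in for ContainsGraphNodeWithName (UniqueName is a hypothetical helper, not TensorFlow code):

#include <cstddef>
#include <iostream>
#include <set>
#include <string>

std::string UniqueName(const std::string& prefix, int id,
                       const std::set<std::string>& taken) {
  const std::string suffix = "_generated";
  std::string name = prefix;
  while (taken.count(name) > 0) {
    const std::size_t pos = name.rfind(suffix);
    if (pos != std::string::npos && pos + suffix.size() == name.size()) {
      // Keep the "_generated" suffix at the end; disambiguate before it.
      name.insert(pos, "/_" + std::to_string(id));
    } else {
      name = prefix + "/_" + std::to_string(id);
    }
    ++id;
  }
  return name;
}

int main() {
  const std::set<std::string> taken = {"LatencyStatsDataset_generated"};
  // Prints "LatencyStatsDataset/_5_generated".
  std::cout << UniqueName("LatencyStatsDataset_generated", 5, taken) << "\n";
  return 0;
}
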
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 0847748802..39c687b501 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -90,10 +90,15 @@ int FindGraphFunctionWithName(const string& name,
// function node does not exist.
int FindFunctionNodeWithName(const string& name, const FunctionDef& function);
-// Returns the index of a node with the given op or -1 if no such node
+// Returns the index of the first node with the given op or -1 if no such node
// exists.
int FindNodeWithOp(const string& op, const GraphDef& graph);
+// Returns the list of indices of all nodes with the given op or empty list if
+// no such node exists.
+std::vector<int> FindAllGraphNodesWithOp(const string& op,
+ const GraphDef& graph);
+
// Sets the node name using `prefix` as a prefix while guaranteeing the name
// is unique across the graph.
void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 59ed79ab8f..e6789d47b5 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -167,10 +167,34 @@ TEST(GraphUtilsTest, FindNodeWithOp) {
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
- EXPECT_NE(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+ AddNode("B", "OpB", {"A"}, {}, &graph);
+ AddNode("A2", "OpA", {"B"}, {}, &graph);
+ EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), 0);
- graph.DeleteNodes({"A"});
+ graph.DeleteNodes({"B"});
+ EXPECT_EQ(FindNodeWithOp("OpB", *graph.GetGraph()), -1);
+ EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
+}
+
+TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+
+ AddNode("A", "OpA", {}, {}, &graph);
+ AddNode("B", "OpB", {"A"}, {}, &graph);
+ AddNode("A2", "OpA", {"B"}, {}, &graph);
+ std::vector<int> result_indices =
+ FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
+ EXPECT_EQ(result_indices.size(), 2);
+ EXPECT_EQ(result_indices.at(0), 0);
+ EXPECT_EQ(result_indices.at(1), 2);
+
+ graph.DeleteNodes({"A2"});
+ std::vector<int> result_indices_new =
+ FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
+ EXPECT_EQ(result_indices_new.size(), 1);
+ EXPECT_EQ(result_indices_new.at(0), 0);
}
TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
new file mode 100644
index 0000000000..0b25b1ea9d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/latency_all_edges.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char kInsertOpName[] = "LatencyStatsDataset";
+
+NodeDef make_latency_node(const NodeDef& node, MutableGraphView* graph) {
+ NodeDef new_node;
+ new_node.set_op(kInsertOpName);
+ graph_utils::SetUniqueGraphNodeName(
+ strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(),
+ &new_node);
+ // Set `node` as the input of the LatencyStatsDataset node.
+ new_node.add_input(node.name());
+
+ NodeDef* tag = graph_utils::AddScalarConstNode<StringPiece>(
+ StringPiece("record_latency_" + node.name()), graph);
+ new_node.add_input(tag->name());
+
+ // Set `output_types` and `output_shapes` attributes.
+ for (auto key : {"output_shapes", "output_types"}) {
+ if (node.attr().find(key) != node.attr().end()) {
+ (*new_node.mutable_attr())[key] = node.attr().at(key);
+ } else {
+ const char* kInferredAttrPrefix = "T";
+ if (node.attr().find(strings::StrCat(kInferredAttrPrefix, key)) !=
+ node.attr().end()) {
+ (*new_node.mutable_attr())[key] =
+ node.attr().at(strings::StrCat(kInferredAttrPrefix, key));
+ }
+ }
+ }
+ return new_node;
+}
+
+} // namespace
+
+Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+
+ // Add LatencyDatasetOp node after each node.
+ // TODO(shivaniagrawal): Add an op to return the latency for the particular
+ // op rather than for the edge (e2 - e1?).
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op().rfind("Dataset") != node.op().size() - strlen("Dataset") ||
+ node.attr().empty() ||
+ node.name().rfind("_generated") ==
+ node.name().size() - strlen("_generated")) {
+ // TODO(b/111805951): Replace this with a non-approximate way to check
+ // whether the node corresponds to a `Dataset` op.
+ continue;
+ }
+ GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0);
+ auto fanout = graph.GetFanout(output_port);
+ if (fanout.size() > 1) {
+ LOG(WARNING) << node.name() << " has fanout size " << fanout.size();
+ continue;
+ } else { // fanout will have size 0 for the last dataset node in the pipeline.
+ if (fanout.size() == 1) {
+ NodeDef* output_node = (*(fanout.begin())).node;
+ if (output_node->name().rfind("_generated") ==
+ output_node->name().size() - strlen("_generated")) {
+ continue;
+ }
+ }
+ }
+
+ graph.InsertNode(node, make_latency_node(node, &graph));
+ }
+ return Status::OK();
+}
+
+void LatencyAllEdges::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(LatencyAllEdges, "latency_all_edges");
+
+} // end namespace grappler
+} // end namespace tensorflow
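The optimizer above identifies dataset nodes and generated nodes purely by name/op suffix. A minimal standalone sketch of that suffix test (EndsWith is a local illustration, not a TensorFlow utility):

#include <iostream>
#include <string>

bool EndsWith(const std::string& s, const std::string& suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
  std::cout << EndsWith("PrefetchDataset", "Dataset") << "\n";                   // 1: gets a LatencyStatsDataset
  std::cout << EndsWith("LatencyStatsDataset_generated", "_generated") << "\n";  // 1: already generated, skipped
  std::cout << EndsWith("Const", "Dataset") << "\n";                             // 0: not a dataset op, skipped
  return 0;
}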
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.h b/tensorflow/core/grappler/optimizers/data/latency_all_edges.h
new file mode 100644
index 0000000000..f6c71a9ec7
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class LatencyAllEdges : public CustomGraphOptimizer {
+ public:
+ LatencyAllEdges() = default;
+ ~LatencyAllEdges() override = default;
+
+ string name() const override { return "latency_all_edges"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc
new file mode 100644
index 0000000000..6789cf5bd6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/latency_all_edges.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(LatencyAllEdgesTest, AddLatenciesAfterTensorMapPrefetch) {
+ using test::function::NDef;
+ GrapplerItem item;
+ NodeDef component_node =
+ NDef("component_nodes", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}});
+ NodeDef from_tensor_node =
+ NDef("from_tensor_nodes", "TensorDataset", {"component_nodes"},
+ {{"Toutput_types", {}}, {"output_shapes", {}}});
+
+ NodeDef captured_input_node = NDef("captured_input_node", "Const", {},
+ {{"value", ""}, {"dtype", DT_STRING}});
+ NodeDef map_node = NDef("map_node", "MapDataset",
+ {"from_tensor_node", "captured_input_node"},
+ {{"f", {}},
+ {"Targumemts", {}},
+ {"output_shapes", {}},
+ {"output_types", {}}});
+ NodeDef buffer_size_node = NDef("buffer_size_node", "Const", {},
+ {{"value", 1}, {"dtype", DT_INT32}});
+ NodeDef prefetch_node = NDef("prefetch_node", "PrefetchDataset",
+ {"map_node", "buffer_size_node"},
+ {{"output_shapes", {}}, {"output_types", {}}});
+
+ item.graph = test::function::GDef({component_node, from_tensor_node,
+ captured_input_node, map_node,
+ buffer_size_node, prefetch_node});
+
+ LatencyAllEdges optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("LatencyStatsDataset", output));
+ std::vector<int> latency_node_indices =
+ graph_utils::FindAllGraphNodesWithOp("LatencyStatsDataset", output);
+ EXPECT_EQ(latency_node_indices.size(), 3);
+ std::vector<NodeDef> dataset_nodes = {std::move(from_tensor_node),
+ std::move(map_node),
+ std::move(prefetch_node)};
+ for (int i = 0; i < latency_node_indices.size(); i++) {
+ NodeDef latency_node = output.node(latency_node_indices[i]);
+ EXPECT_EQ(latency_node.input_size(), 2);
+ EXPECT_EQ(latency_node.input(0), dataset_nodes[i].name());
+ EXPECT_TRUE(
+ AreAttrValuesEqual(latency_node.attr().at("output_shapes"),
+ dataset_nodes[i].attr().at("output_shapes")));
+ if (dataset_nodes[i].attr().find("output_types") !=
+ dataset_nodes[i].attr().end()) {
+ EXPECT_TRUE(
+ AreAttrValuesEqual(latency_node.attr().at("output_types"),
+ dataset_nodes[i].attr().at("output_types")));
+ } else {
+ if (dataset_nodes[i].attr().find("Toutput_types") !=
+ dataset_nodes[i].attr().end()) {
+ EXPECT_TRUE(
+ AreAttrValuesEqual(latency_node.attr().at("output_types"),
+ dataset_nodes[i].attr().at("Toutput_types")));
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index 223efd2670..f3a07be728 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -464,7 +464,25 @@ std::vector<int> GetStackPushNodesToConvert(
const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
if (IsStackPushOp(fanout_node)) {
- nodes_to_convert.push_back(fanout_idx);
+ // Check that the stack itself is not a node we want to preserve. This can
+ // happen when the graph we have contains only the forward pass for a loop
+ // (as when the forward and backward passes are split across different
+ // functions).
+ if (graph_view.has_node(fanout_node.input(0))) {
+ const NodeDef* stack_node =
+ &graph_view.node(graph_view.index(fanout_node.input(0)));
+ while (stack_node->op() != "Stack" && stack_node->op() != "StackV2" &&
+ stack_node->input_size() > 0 &&
+ graph_view.has_node(stack_node->input(0))) {
+ stack_node = &graph_view.node(graph_view.index(stack_node->input(0)));
+ }
+ if (nodes_to_preserve.find(stack_node->name()) ==
+ nodes_to_preserve.end()) {
+ nodes_to_convert.push_back(fanout_idx);
+ }
+ } else {
+ nodes_to_convert.push_back(fanout_idx);
+ }
} else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
op_types_to_traverse.find(fanout_node.op()) !=
op_types_to_traverse.end()) {
@@ -549,7 +567,6 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
const NodeMap& node_map,
- bool is_optimization_aggressive,
DeviceBase* cpu_device, ResourceMgr* resource_mgr,
bool* has_dead_fanout, int* dead_fanout) {
*has_dead_fanout = false;
@@ -571,7 +588,7 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
// We check if it's a while loop such that the condition is a simple binary
// operator which returns false for the initialization value.
// TODO(srjoglekar): Improve to work with arbitrary predicate subgraphs.
- if (!is_optimization_aggressive || !IsMerge(*switch_input)) {
+ if (!IsMerge(*switch_input)) {
return Status::OK();
}
@@ -604,27 +621,27 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
if (merge_node == nullptr || constant_ctrl_input == nullptr) {
return Status::OK();
}
- // Find Enter.
- // TODO(srjoglekar): Reconcile this with the optimization in
- // ConstantFolding::MoveConstantsPastEnter
+ // Find the initialization constant (via Enter, if one exists).
NodeDef* enter_node = nullptr;
+ NodeDef* constant_init_node = nullptr;
for (const auto& input : merge_node->input()) {
NodeDef* node = node_map.GetNode(input);
if (IsEnter(*node)) {
enter_node = node;
}
- }
- if (enter_node == nullptr) {
- return Status::OK();
- }
- // Find the initialization constant.
- NodeDef* constant_init_node = nullptr;
- for (const auto& input : enter_node->input()) {
- NodeDef* node = node_map.GetNode(input);
if (IsConstant(*node)) {
constant_init_node = node;
}
}
+ if (enter_node != nullptr) {
+ if (constant_init_node != nullptr) return Status::OK();
+ for (const auto& input : enter_node->input()) {
+ NodeDef* node = node_map.GetNode(input);
+ if (IsConstant(*node)) {
+ constant_init_node = node;
+ }
+ }
+ }
if (constant_init_node == nullptr) {
return Status::OK();
}
@@ -704,9 +721,9 @@ Status LoopOptimizer::RemoveDeadBranches(
int dead_fanout;
bool has_dead_fanout;
- TF_RETURN_IF_ERROR(CheckForDeadFanout(
- view, node, node_map, opt_level_ == RewriterConfig::AGGRESSIVE,
- cpu_device_, resource_mgr_.get(), &has_dead_fanout, &dead_fanout));
+ TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, cpu_device_,
+ resource_mgr_.get(), &has_dead_fanout,
+ &dead_fanout));
if (!has_dead_fanout) {
continue;
}
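A standalone sketch of the stack-resolution walk added above: starting from the push's first input, follow input(0) through pass-through nodes (Enter in the real graph) until a Stack-like node is reached, then consult the preserve set. The Node struct and graph map below are illustrative assumptions, not the GraphDef/SimpleGraphView API:

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Node { std::string op; std::vector<std::string> inputs; };

int main() {
  std::map<std::string, Node> graph = {
      {"stack2", {"StackV2", {}}},
      {"enter2_stack2", {"Enter", {"stack2"}}},
      {"push2", {"StackPushV2", {"enter2_stack2", "enter2_c"}}}};
  std::set<std::string> preserve = {"stack2"};

  std::string cur = graph["push2"].inputs[0];
  while (graph.count(cur) && graph[cur].op != "Stack" &&
         graph[cur].op != "StackV2" && !graph[cur].inputs.empty()) {
    cur = graph[cur].inputs[0];  // walk through the Enter node to the stack
  }
  // "push2" is kept (not converted) because the resolved stack is preserved.
  std::cout << cur << " preserved=" << (preserve.count(cur) ? 1 : 0) << "\n";  // stack2 preserved=1
  return 0;
}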
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
index f5fe28d4ba..81f40db8f0 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
@@ -536,6 +536,29 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
+TEST_F(LoopOptimizerTest, RemovePush_NoPopButStackLives) {
+ GrapplerItem item;
+ GraphDef& graph = item.graph;
+ AddSimpleNode("c", "Const", {}, &graph);
+ // Stack with corresponding push
+ AddSimpleNode("stack1", "StackV2", {}, &graph);
+ AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
+ // Stack with corresponding push behind Enter.
+ AddSimpleNode("stack2", "StackV2", {}, &graph);
+ AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, &graph);
+ AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, &graph);
+ AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &graph);
+ item.keep_ops.push_back("stack1");
+ item.keep_ops.push_back("stack2");
+
+ LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ VerifyGraphsEqual(item.graph, output, __FUNCTION__);
+}
+
TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
GrapplerItem item;
GraphDef& graph = item.graph;
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index b297caa8d4..a9c34b6d08 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -239,6 +239,9 @@ class SimpleGraphView {
const GraphDef* graph() const { return graph_; }
inline int num_nodes() const { return index_to_name_.size(); }
+ inline bool has_node(const string& node_name) const {
+ return name_to_index_.find(node_name) != name_to_index_.end();
+ }
inline const int index(const string& node_name) const {
const auto& it = name_to_index_.find(node_name);
DCHECK(it != name_to_index_.end());
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index d64cb49715..fd71406d2c 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -119,7 +119,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
if (Scanner(remaining)
.OneLiteral(":")
.RestartCapture()
- .One(strings::Scanner::LOWERLETTER)
+ .One(strings::Scanner::LETTER)
.Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
.GetResult(&remaining, &capture)) {
node_output = string(capture.data(), capture.size());
diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc
index ff89035902..63ca92c69e 100644
--- a/tensorflow/core/grappler/utils/topological_sort.cc
+++ b/tensorflow/core/grappler/utils/topological_sort.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include <algorithm>
#include <deque>
#include <unordered_map>
#include "tensorflow/core/framework/node_def.pb.h"
@@ -85,6 +86,14 @@ Status ComputeTopologicalOrder(
return Status::OK();
}
+Status ReversedTopologicalSort(GraphDef* graph) {
+ std::vector<int> ready_nodes;
+ TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
+ std::reverse(ready_nodes.begin(), ready_nodes.end());
+ PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
+ return Status::OK();
+}
+
Status TopologicalSort(GraphDef* graph) {
std::vector<int> ready_nodes;
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h
index bc0299a7b8..b8cf897a32 100644
--- a/tensorflow/core/grappler/utils/topological_sort.h
+++ b/tensorflow/core/grappler/utils/topological_sort.h
@@ -31,6 +31,9 @@ Status ComputeTopologicalOrder(
// Sort a graph in topological order.
Status TopologicalSort(GraphDef* graph);
+// Sort a graph in topological order and reverse it.
+Status ReversedTopologicalSort(GraphDef* graph);
+
} // namespace grappler
} // namespace tensorflow
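A standalone sketch of what ReversedTopologicalSort does, assuming (as the call above suggests) that ready_nodes[i] holds the index of the node that should land at position i, which is why the permutation is applied with invert_permutation=true. Containers and names below are illustrative, not TensorFlow APIs:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> nodes = {"c", "a", "b"};  // arbitrary initial order
  std::vector<int> order = {1, 2, 0};  // position i should hold nodes[order[i]]
  std::vector<std::string> sorted(nodes.size());
  for (size_t i = 0; i < order.size(); ++i) sorted[i] = nodes[order[i]];
  for (const auto& n : sorted) std::cout << n << " ";  // a b c  (TopologicalSort)
  std::cout << "\n";
  std::reverse(order.begin(), order.end());  // ReversedTopologicalSort reverses the order first
  for (size_t i = 0; i < order.size(); ++i) sorted[i] = nodes[order[i]];
  for (const auto& n : sorted) std::cout << n << " ";  // c b a  (reversed topological order)
  std::cout << "\n";
  return 0;
}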
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 6126e8b7ba..82fab76c49 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -782,7 +782,7 @@ tf_kernel_library(
tf_kernel_library(
name = "quantize_and_dequantize_op",
prefix = "quantize_and_dequantize_op",
- deps = ARRAY_DEPS,
+ deps = ARRAY_DEPS + [":cwise_op"],
)
tf_kernel_library(
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 90bc3da0cd..e9175e768d 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -554,6 +554,7 @@ tf_kernel_library(
deps = [
":dataset",
":dataset_utils",
+ ":optional_ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -566,6 +567,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "optional_ops",
+ srcs = ["optional_ops.cc"],
+ hdrs = ["optional_ops.h"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
name = "cache_dataset_ops",
srcs = ["cache_dataset_ops.cc"],
deps = [
@@ -631,6 +646,7 @@ tf_kernel_library(
":map_and_batch_dataset_op",
":map_dataset_op",
":optimize_dataset_op",
+ ":optional_ops",
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index ed4932bf32..86b0840aea 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -39,7 +39,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument<string>(ctx, "filename", &filename));
if (filename.empty()) {
- *output = new MemoryDataset(input);
+ *output = new MemoryDataset(ctx, input);
} else {
*output = new FileDataset(ctx, input, filename, ctx->env());
}
@@ -68,8 +68,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new FileCacheIterator(
- {this, strings::StrCat(prefix, "::FileCacheIterator")}));
+ return std::unique_ptr<IteratorBase>(
+ new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -105,9 +105,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
tensor_index);
}
- class FileCacheIterator : public DatasetIterator<FileDataset> {
+ class FileIterator : public DatasetIterator<FileDataset> {
public:
- explicit FileCacheIterator(const Params& params)
+ explicit FileIterator(const Params& params)
: DatasetIterator<FileDataset>(params) {
if (params.dataset->env_
->FileExists(MetaFilename(params.dataset->filename_))
@@ -526,7 +526,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
enum Mode { read, write };
Mode mode_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
- }; // FileCacheIterator
+ }; // FileIterator
const DatasetBase* const input_;
const string filename_;
@@ -538,9 +538,10 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
const string tensor_format_string_;
}; // FileDataset
- class MemoryDataset : public DatasetBase {
+ class MemoryDataset : public GraphDatasetBase {
public:
- explicit MemoryDataset(const DatasetBase* input) : input_(input) {
+ explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
+ : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) {
input->Ref();
}
@@ -548,18 +549,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- mutex_lock l(mu_);
- if (cache_) {
- return std::unique_ptr<IteratorBase>(new MemoryReaderIterator(
- {this, strings::StrCat(prefix, "::MemoryReader")}, cache_.get()));
- }
- if (!writer_iterator_created_) {
- writer_iterator_created_ = true;
- return std::unique_ptr<IteratorBase>(new MemoryWriterIterator(
- {this, strings::StrCat(prefix, "::MemoryWriter")}));
- }
- return std::unique_ptr<IteratorBase>(new DuplicateWriterIterator(
- {this, strings::StrCat(prefix, "::DuplicateWriter")}));
+ return std::unique_ptr<IteratorBase>(new MemoryIterator(
+ {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
@@ -574,114 +565,321 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
return "CacheDatasetOp::MemoryDataset";
}
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
+ Node* filename_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_node, filename_node}, output));
+ return Status::OK();
+ }
+
private:
- // MemoryWriterIterator passes through and appends items from the input
- // dataset to its vector.
+ // A thread-safe data structure for caching dataset elements.
//
- // This iterator is used when dataset->cache_ is null. After buffering
- // the tensors in memory, upon exhausing the underlying iterator, they are
- // updated into the parent dataset's cache_ pointer.
- class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
+ // The expected use is that a single `MemoryWriterIterator` populates the
+ // cache with dataset elements. Once all elements are cached, the cache can
+ // be used by one or more `MemoryReaderIterator`s.
+ class MemoryCache {
public:
- explicit MemoryWriterIterator(const Params& params)
- : DatasetIterator<MemoryDataset>(params),
- cache_(new std::vector<std::vector<Tensor>>) {}
+ MemoryCache() = default;
- ~MemoryWriterIterator() override {
+ // Marks the cache as completed.
+ void Complete() {
mutex_lock l(mu_);
- if (cache_) {
- LOG(ERROR)
- << "The calling iterator did not fully read the dataset we were "
- "attempting to cache. In order to avoid unexpected truncation "
- "of the sequence, the current [partially cached] sequence "
- "will be dropped. This can occur if you have a sequence "
- "similar to `dataset.cache().take(k).repeat()`. Instead, swap "
- "the order (i.e. `dataset.take(k).cache().repeat()`)";
- mutex_lock l2(dataset()->mu_);
- dataset()->writer_iterator_created_ = false;
- }
+ completed_ = true;
}
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ // Returns whether the cache is claimed.
+ bool IsClaimed() {
+ tf_shared_lock l(mu_);
+ return claimed_;
}
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
+ // Returns whether the cache is completed.
+ bool IsCompleted() {
+ tf_shared_lock l(mu_);
+ return completed_;
+ }
+
+ // Attempts to claim the cache, returning whether the cache was claimed.
+ bool MaybeClaim() {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(
- input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
- if (*end_of_sequence) {
- // Guard on cache_ to not crash if GetNext is called a second time
- // after *end_of_sequence == true
- if (cache_) {
- mutex_lock l(dataset()->mu_);
- DCHECK(dataset()->writer_iterator_created_);
- DCHECK(!dataset()->cache_);
- cache_.swap(dataset()->cache_);
- }
- return Status::OK();
+ if (!claimed_) {
+ claimed_ = true;
+ return true;
}
- cache_->emplace_back(*out_tensors);
- return Status::OK();
+ return false;
+ }
+
+ // Resets the cache.
+ void Reset() {
+ mutex_lock l(mu_);
+ claimed_ = false;
+ completed_ = false;
+ cache_.clear();
+ }
+
+ // Returns the element at the given index.
+ const std::vector<Tensor>& at(int64 index) {
+ tf_shared_lock l(mu_);
+ DCHECK(index < cache_.size());
+ return cache_[index];
+ }
+
+ // Adds the element to the cache.
+ void emplace_back(std::vector<Tensor> element) {
+ mutex_lock l(mu_);
+ cache_.emplace_back(std::move(element));
+ }
+
+ // Returns the size of the cache.
+ size_t size() {
+ tf_shared_lock l(mu_);
+ return cache_.size();
}
private:
mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- std::unique_ptr<std::vector<std::vector<Tensor>>> cache_ GUARDED_BY(mu_);
- }; // MemoryWriterIterator
-
- class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
+ // Determines whether a writer has claimed the cache.
+ bool claimed_ GUARDED_BY(mu_) = false;
+ // Determines whether all elements of the dataset have been cached.
+ bool completed_ GUARDED_BY(mu_) = false;
+ std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
+ };
+
+ class MemoryIterator : public DatasetIterator<MemoryDataset> {
public:
- explicit MemoryReaderIterator(
- const Params& params, const std::vector<std::vector<Tensor>>* cache)
- : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) {
- CHECK(cache);
+ explicit MemoryIterator(const Params& params,
+ const std::shared_ptr<MemoryCache>& cache)
+ : DatasetIterator<MemoryDataset>(params), cache_(cache) {
+ mode_ = cache->MaybeClaim() ? Mode::write : Mode::read;
+ InitializeIterator();
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (mode_ == Mode::read && !cache_->IsCompleted()) {
+ return errors::Internal(
+ "Cache should only be read after it has been completed.");
+ }
+ return iterator_->Initialize(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- if (index_ < cache_->size()) {
- const std::vector<Tensor>& cache_tensors = (*cache_)[index_];
- out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
- cache_tensors.end());
- index_++;
- *end_of_sequence = false;
- return Status::OK();
- } else {
- *end_of_sequence = true;
- return Status::OK();
+ return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_));
+ if (cache_->IsClaimed()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cache_claimed"), ""));
+ size_t cache_size = cache_->size();
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cache_size"), cache_size));
+ for (size_t i = 0; i < cache_size; i++) {
+ auto& element = cache_->at(i);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("cache[", i, "].size")),
+ element.size()));
+ for (size_t j = 0; j < element.size(); ++j) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("cache[", i, "][", j, "]")),
+ element[j]));
+ }
+ }
+ if (cache_->IsCompleted()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cache_completed"), ""));
+ }
}
+ return SaveParent(writer, iterator_);
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ iterator_.reset();
+ cache_->Reset();
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp));
+ mode_ = static_cast<Mode>(temp);
+ }
+ if (reader->Contains(full_name("cache_claimed"))) {
+ CHECK(cache_->MaybeClaim());
+ size_t cache_size;
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cache_size"), &temp));
+ cache_size = static_cast<size_t>(temp);
+ }
+ for (size_t i = 0; i < cache_size; ++i) {
+ std::vector<Tensor> element;
+ size_t element_size;
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("cache[", i, "].size")), &temp));
+ element_size = static_cast<size_t>(temp);
+ }
+ element.reserve(element_size);
+ for (size_t j = 0; j < element_size; ++j) {
+ element.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("cache[", i, "][", j, "]")),
+ &element.back()));
+ }
+ cache_->emplace_back(std::move(element));
+ }
+ if (reader->Contains(full_name("cache_completed"))) {
+ cache_->Complete();
+ }
+ }
+ InitializeIterator();
+ TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+ return RestoreParent(ctx, reader, iterator_);
}
private:
- mutex mu_;
- const std::vector<std::vector<Tensor>>* const cache_;
- size_t index_ GUARDED_BY(mu_);
- }; // MemoryReaderIterator
+ class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
+ public:
+ explicit MemoryWriterIterator(const Params& params,
+ const std::shared_ptr<MemoryCache>& cache)
+ : DatasetIterator<MemoryDataset>(params), cache_(cache) {
+ CHECK(cache_);
+ }
- class DuplicateWriterIterator : public DatasetIterator<MemoryDataset> {
- public:
- explicit DuplicateWriterIterator(const Params& params)
- : DatasetIterator<MemoryDataset>(params) {}
+ ~MemoryWriterIterator() override {
+ mutex_lock l(mu_);
+ if (cache_->size() > 0 && !cache_->IsCompleted()) {
+ LOG(WARNING)
+ << "The calling iterator did not fully read the dataset being "
+ "cached. In order to avoid unexpected truncation of the "
+ "dataset, the partially cached contents of the dataset"
+ "will be discarded. This can happen if you have an input "
+ "pipeline similar to `dataset.cache().take(k).repeat()`. "
+ "You should use `dataset.take(k).cache().repeat()` instead.";
+ cache_->Reset();
+ }
+ }
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- return errors::AlreadyExists(
- "There appears to be a concurrent caching iterator running.");
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ cache_->Complete();
+ return Status::OK();
+ }
+ cache_->emplace_back(*out_tensors);
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ return SaveParent(writer, input_impl_);
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ return RestoreParent(ctx, reader, input_impl_);
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::shared_ptr<MemoryCache> cache_;
+ }; // MemoryWriterIterator
+
+ class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
+ public:
+ explicit MemoryReaderIterator(const Params& params,
+ const std::shared_ptr<MemoryCache>& cache)
+ : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) {
+ CHECK(cache);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("index"), &temp));
+ index_ = static_cast<size_t>(temp);
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (index_ < cache_->size()) {
+ const std::vector<Tensor>& cache_tensors = cache_->at(index_);
+ out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
+ cache_tensors.end());
+ index_++;
+ *end_of_sequence = false;
+ return Status::OK();
+ } else {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+
+ private:
+ mutex mu_;
+ const std::shared_ptr<MemoryCache> cache_;
+ size_t index_ GUARDED_BY(mu_);
+ }; // MemoryReaderIterator
+
+ void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ switch (mode_) {
+ case Mode::read:
+ iterator_.reset(
+ new MemoryReaderIterator({dataset(), prefix()}, cache_));
+ break;
+ case Mode::write:
+ iterator_.reset(
+ new MemoryWriterIterator({dataset(), prefix()}, cache_));
+ }
}
- }; // DuplicateWriterIterator
+
+ mutex mu_;
+ std::shared_ptr<MemoryCache> cache_;
+ enum Mode { read, write };
+ Mode mode_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
+ }; // MemoryIterator
const DatasetBase* const input_;
- mutable mutex mu_;
- mutable std::unique_ptr<std::vector<std::vector<Tensor>>> cache_
- GUARDED_BY(mu_);
- mutable bool writer_iterator_created_ GUARDED_BY(mu_) = false;
+ const std::shared_ptr<MemoryCache> cache_;
}; // MemoryDataset
}; // CacheDatasetOp
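A minimal standalone sketch of the claim/complete protocol that MemoryCache implements above: the first MemoryIterator to claim the cache becomes the writer, later iterators read, and reads are only valid once the writer marks the cache complete. std::mutex stands in for TensorFlow's mutex/tf_shared_lock, and the class below is an illustration, not the real implementation:

#include <iostream>
#include <mutex>
#include <vector>

class TinyCache {
 public:
  bool MaybeClaim() {  // the first caller becomes the writer
    std::lock_guard<std::mutex> l(mu_);
    if (claimed_) return false;
    claimed_ = true;
    return true;
  }
  void Complete() { std::lock_guard<std::mutex> l(mu_); completed_ = true; }
  bool IsCompleted() { std::lock_guard<std::mutex> l(mu_); return completed_; }
  void Add(int v) { std::lock_guard<std::mutex> l(mu_); data_.push_back(v); }
  size_t size() { std::lock_guard<std::mutex> l(mu_); return data_.size(); }

 private:
  std::mutex mu_;
  bool claimed_ = false;
  bool completed_ = false;
  std::vector<int> data_;
};

int main() {
  TinyCache cache;
  bool writer = cache.MaybeClaim();  // true: this "iterator" writes
  bool reader = cache.MaybeClaim();  // false: a second iterator would read
  cache.Add(1);
  cache.Add(2);
  cache.Complete();  // readers may now consume the cached elements
  std::cout << writer << " " << reader << " " << cache.IsCompleted() << " "
            << cache.size() << "\n";  // 1 0 1 2
  return 0;
}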
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index da489db7c8..86adbc4f47 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -1084,6 +1085,86 @@ class IteratorGetNextSyncOp : public OpKernel {
}
};
+class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
+ public:
+ explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ background_worker_(
+ ctx->env(), strings::StrCat("iterator_get_next_as_optional_thread_",
+ SanitizeThreadSuffix(name()))) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ IteratorResource* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from this
+ // kernel's owned thread pool.
+ background_worker_.Schedule(std::bind(
+ [this, ctx, iterator](DoneCallback done) {
+ std::vector<Tensor> components;
+ bool end_of_sequence = false;
+
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ Status s =
+ iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ // NOTE(mrry): We must unref the iterator before calling `done()`, to
+ // avoid destruction races.
+ iterator->Unref();
+
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (end_of_sequence) {
+ OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
+ } else {
+ for (int i = 0; i < components.size(); ++i) {
+ OP_REQUIRES_ASYNC(
+ ctx, components[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The given optional does not match the expected type for "
+ "component ",
+ i, ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(components[i].dtype()), "."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx,
+ output_shapes_[i].IsCompatibleWith(components[i].shape()),
+ errors::InvalidArgument(
+ "The given optional does not match the expected shape "
+ "for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", components[i].shape().DebugString(), "."),
+ done);
+ }
+
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
+ done);
+ }
+ done();
+ },
+ std::move(done)));
+ }
+
+ private:
+ BackgroundWorker background_worker_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
class IteratorToStringHandleOp : public OpKernel {
public:
explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
@@ -1240,6 +1321,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU),
IteratorGetNextSyncOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_CPU),
+ IteratorGetNextAsOptionalOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_GPU),
+ IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
new file mode 100644
index 0000000000..cfac45dbc7
--- /dev/null
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -0,0 +1,270 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/optional_ops.h"
+
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+
+namespace tensorflow {
+namespace {
+const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
+
+// An `OptionalVariant` can represent either an "actual value" (a tuple of
+// tensors) or "none", and may be stored in a DT_VARIANT tensor.
+class OptionalVariant {
+ public:
+ // Create an `OptionalVariant` with no actual value.
+ OptionalVariant() : values_(nullptr) {}
+
+ // Create an `OptionalVariant` with the actual value given by the tuple of
+ // tensors in `values`.
+ explicit OptionalVariant(std::vector<Tensor> values)
+ : values_(new std::vector<Tensor>(std::move(values))) {}
+
+ OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
+
+ // Returns true if `this` represents an actual value.
+ bool has_value() const { return values_ != nullptr; }
+
+ // REQUIRES: `this->has_value()` must be true.
+ const std::vector<Tensor>& get_values() const {
+ CHECK(values_) << "Tried to get values from an empty OptionalVariant";
+ return *values_;
+ }
+
+ // Implementations of the necessary methods for using `OptionalVariant`
+ // objects in DT_VARIANT tensors.
+ string TypeName() const { return kOptionalVariantTypeName; }
+ void Encode(VariantTensorData* data) const {
+ data->set_metadata(values_ != nullptr);
+ if (values_ != nullptr) {
+ for (const auto& t : *values_) {
+ *(data->add_tensors()) = t;
+ }
+ }
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ if (data.type_name() != TypeName()) {
+ return false;
+ }
+ bool has_value = false;
+ if (!data.get_metadata(&has_value)) {
+ return false;
+ }
+ if (has_value) {
+ values_.reset(new std::vector<Tensor>(data.tensors()));
+ } else {
+ values_.reset();
+ }
+ return true;
+ }
+
+ string DebugString() const {
+ if (values_) {
+ return strings::StrCat("OptionalVariant<", "values: (",
+ str_util::Join(*values_, ", ",
+ [](string* s, const Tensor& elem) {
+ *s = elem.DebugString();
+ }),
+ ")>");
+ } else {
+ return strings::StrCat("OptionalVariant<None>");
+ }
+ }
+
+ private:
+ std::shared_ptr<const std::vector<Tensor>> values_;
+};
+
+class OptionalNoneOp : public OpKernel {
+ public:
+ explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0));
+ }
+};
+
+class OptionalFromValueOp : public OpKernel {
+ public:
+ explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ OpInputList components_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
+ std::vector<Tensor> components;
+ components.reserve(components_input.size());
+ for (const Tensor& component_t : components_input) {
+ components.push_back(component_t);
+ }
+ OP_REQUIRES_OK(
+ ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
+ }
+};
+
+class OptionalHasValueOp : public OpKernel {
+ public:
+ explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* optional_input;
+ OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be a scalar tensor "
+ "containing an OptionalVariant object."));
+ const OptionalVariant* optional =
+ optional_input->scalar<Variant>()().get<OptionalVariant>();
+ OP_REQUIRES(
+ ctx, optional != nullptr,
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be an OptionalVariant object."));
+ Tensor* result;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
+ result->scalar<bool>()() = optional->has_value();
+ }
+};
+
+class OptionalGetValueOp : public OpKernel {
+ public:
+ explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* optional_input;
+ OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be a scalar tensor "
+ "containing an OptionalVariant object."));
+ const OptionalVariant* optional =
+ optional_input->scalar<Variant>()().get<OptionalVariant>();
+ OP_REQUIRES(
+ ctx, optional != nullptr,
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be an OptionalVariant object."));
+ OP_REQUIRES(
+ ctx, optional->has_value(),
+ errors::InvalidArgument("The given optional does not have a value."));
+ const auto& components = optional->get_values();
+ for (int i = 0; i < components.size(); ++i) {
+ OP_REQUIRES(
+ ctx, components[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The given optional does not match the expected type for "
+ "component ",
+ i, ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(components[i].dtype()), "."));
+ OP_REQUIRES(ctx,
+ output_shapes_[i].IsCompatibleWith(components[i].shape()),
+ errors::InvalidArgument(
+ "The given optional does not match the expected shape "
+ "for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", components[i].shape().DebugString(), "."));
+ ctx->set_output(i, components[i]);
+ }
+ }
+
+ private:
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU),
+ OptionalNoneOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU),
+ OptionalNoneOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_CPU),
+ OptionalFromValueOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_GPU),
+ OptionalFromValueOp);
+
+REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU),
+ OptionalHasValueOp);
+REGISTER_KERNEL_BUILDER(
+ Name("OptionalHasValue").Device(DEVICE_GPU).HostMemory("has_value"),
+ OptionalHasValueOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU),
+ OptionalGetValueOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU),
+ OptionalGetValueOp);
+
+static Status OptionalDeviceCopy(
+ const OptionalVariant& from, OptionalVariant* to,
+ const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
+ if (from.has_value()) {
+ const std::vector<Tensor>& from_values = from.get_values();
+ std::vector<Tensor> to_values;
+ to_values.reserve(from_values.size());
+ for (const Tensor& t : from_values) {
+ if (DMAHelper::CanUseDMA(&t)) {
+ Tensor tmp(t.dtype());
+ TF_RETURN_IF_ERROR(copy(t, &tmp));
+ to_values.push_back(std::move(tmp));
+ } else {
+ to_values.push_back(t);
+ }
+ }
+ *to = OptionalVariant(std::move(to_values));
+ } else {
+ *to = from;
+ }
+ return Status::OK();
+}
+
+#define REGISTER_OPTIONAL_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
+ OptionalDeviceCopy)
+
+REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
+REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
+REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant,
+ kOptionalVariantTypeName);
+
+} // namespace
+
+Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
+ std::vector<Tensor> value) {
+ OptionalVariant v(std::move(value));
+ Tensor* variant_t;
+ AllocatorAttributes cpu_alloc;
+ cpu_alloc.set_on_host(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
+ &variant_t, cpu_alloc));
+ variant_t->scalar<Variant>()() = v;
+ return Status::OK();
+}
+
+Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
+ OptionalVariant v;
+ Tensor* variant_t;
+ AllocatorAttributes cpu_alloc;
+ cpu_alloc.set_on_host(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
+ &variant_t, cpu_alloc));
+ variant_t->scalar<Variant>()() = v;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h
new file mode 100644
index 0000000000..6f25567678
--- /dev/null
+++ b/tensorflow/core/kernels/data/optional_ops.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+
+namespace tensorflow {
+
+// Stores a DT_VARIANT value representing an Optional with the given value
+// in the `output_index`^th output of the given kernel execution context.
+Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
+ std::vector<Tensor> value);
+
+// Stores a DT_VARIANT value representing an Optional with no value
+// in the `output_index`^th output of the given kernel execution context.
+Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
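A standalone analogue of the Optional* kernels added above: a value that either holds a tuple of payloads or is "none", mirroring the has_value()/get_values() surface of OptionalVariant. Plain std:: types stand in for Tensor and Variant; the class and names below are illustrative only:

#include <iostream>
#include <memory>
#include <vector>

class SimpleOptional {
 public:
  SimpleOptional() = default;                        // analogue of OptionalNone
  explicit SimpleOptional(std::vector<int> values)   // analogue of OptionalFromValue
      : values_(std::make_shared<const std::vector<int>>(std::move(values))) {}
  bool has_value() const { return values_ != nullptr; }            // OptionalHasValue
  const std::vector<int>& get_values() const { return *values_; }  // OptionalGetValue
 private:
  std::shared_ptr<const std::vector<int>> values_;
};

int main() {
  SimpleOptional none;
  SimpleOptional some({3, 4});
  std::cout << none.has_value() << " " << some.has_value() << " "
            << some.get_values().front() << "\n";  // 0 1 3
  return 0;
}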
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index cb285bf732..1c0abf26cd 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -127,31 +127,47 @@ class IfOp : public AsyncOpKernel {
explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
auto lib = ctx->function_library();
OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
- const NameAttrList* func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &func));
- OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func));
- OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
}
~IfOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ auto lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library"), done);
+
+ // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
+ // registration, this kernel may be shared by multiple subgraphs, which have
+ // different associated `FunctionLibraryRuntime` objects and hence different
+ // `FHandle` namespaces. So we must call Instantiate() to make sure we get
+ // the correct function handles with respect to `lib`. Note the underlying
+ // `lib->Instantiate()` caches the created function handles, so calling
+ // `Instantiate()` repeatedly on the same `lib` and function is cheap.
+ FHandle then_handle;
+ FHandle else_handle;
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done);
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done);
+
bool cond;
OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
- (new State(this, ctx, cond, done))->Start();
+ (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
}
private:
- FHandle then_handle_;
- FHandle else_handle_;
+ NameAttrList then_func_;
+ NameAttrList else_func_;
class State {
public:
- State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done)
+ State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
+ FHandle else_handle, DoneCallback done)
: kernel_(kernel),
ctx_(ctx),
cond_(cond),
+ then_handle_(then_handle),
+ else_handle_(else_handle),
done_(std::move(done)),
lib_(CHECK_NOTNULL(ctx_->function_library())) {
SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
@@ -163,7 +179,7 @@ class IfOp : public AsyncOpKernel {
~State() {}
void Start() {
- FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_;
+ FHandle handle = cond_ ? then_handle_ : else_handle_;
rets_.clear();
lib_->Run(
// Evaluate one of the branches.
@@ -184,6 +200,8 @@ class IfOp : public AsyncOpKernel {
IfOp* const kernel_;
OpKernelContext* const ctx_;
const bool cond_;
+ FHandle then_handle_;
+ FHandle else_handle_;
DoneCallback done_;
FunctionLibraryRuntime* const lib_;
FunctionLibraryRuntime::Options opts_;
@@ -214,30 +232,17 @@ class WhileOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library"), done);
- // TODO(b/37549631): Because this op has `SetIsStateful()` in its
- // op registration, this kernel may be shared by multiple
- // subgraphs, which have different associated
- // `FunctionLibraryRuntime` objects and hence different `FHandle`
- // namespaces. We currently work around this by caching the map
- // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
- // functions this op uses.
+ // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
+ // registration, this kernel may be shared by multiple subgraphs, which have
+ // different associated `FunctionLibraryRuntime` objects and hence different
+ // `FHandle` namespaces. So we must call Instantiate() to make sure we get
+ // the correct function handles with respect to `lib`. Note the underlying
+ // `lib->Instantiate()` caches the created function handles, so calling
+ // `Instantiate()` repeatedly on the same `lib` and function is cheap.
FHandle cond_handle;
FHandle body_handle;
- {
- mutex_lock l(mu_);
- const auto iter = handles_.find(lib);
- if (iter == handles_.end()) {
- OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle),
- done);
- OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle),
- done);
- handles_[lib] = {cond_handle, body_handle};
- } else {
- cond_handle = iter->second.first;
- body_handle = iter->second.second;
- }
- }
-
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done);
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done);
(new State(this, ctx, cond_handle, body_handle, done))->Start();
}
@@ -245,10 +250,6 @@ class WhileOp : public AsyncOpKernel {
NameAttrList cond_func_;
NameAttrList body_func_;
- mutex mu_;
- std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
- handles_ GUARDED_BY(mu_);
-
class State {
public:
State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
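A standalone sketch of why the per-call Instantiate() pattern above is cheap, as the comments note: the runtime memoizes handles per function, and a handle is only meaningful with respect to the library that created it. FakeRuntime below is an illustration, not the FunctionLibraryRuntime API:

#include <iostream>
#include <map>
#include <string>

using Handle = int;

class FakeRuntime {
 public:
  Handle Instantiate(const std::string& func) {
    auto it = cache_.find(func);
    if (it != cache_.end()) return it->second;  // cached: no re-instantiation
    Handle h = next_++;
    cache_.emplace(func, h);
    return h;
  }
 private:
  std::map<std::string, Handle> cache_;
  Handle next_ = 0;
};

int main() {
  FakeRuntime lib_a, lib_b;
  std::cout << lib_a.Instantiate("then_branch") << " "   // 0: new handle in lib_a
            << lib_a.Instantiate("else_branch") << " "   // 1: new handle in lib_a
            << lib_a.Instantiate("then_branch") << " "   // 0: cached, repeat call is cheap
            << lib_b.Instantiate("then_branch") << "\n"; // 0: a distinct handle, valid only in lib_b
  return 0;
}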
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h
index 782263e4e9..6b0c5e5a46 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.h
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
namespace tensorflow {
namespace functor {
@@ -89,17 +90,14 @@ struct QuantizeAndDequantizeOneScaleImpl {
// min_range and max_range - because we may have changed either min_range
// or max_range.
out.device(d) =
- ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) * scale +
- T(0.5))
- .floor() *
- inverse_scale +
- min_range;
+ (input.cwiseMin(max_range).cwiseMax(min_range) * scale)
+ .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) *
+ inverse_scale;
} else {
- // No need to clamp to min_range and max_range in this case as they were
- // measured from the tensor.
out.device(d) =
- ((input - min_range) * scale + T(0.5)).floor() * inverse_scale +
- min_range;
+ (input * scale)
+ .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) *
+ inverse_scale;
}
}
};
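The behavioral change above is in the tie-breaking rule: the old expression computed floor(t + 0.5), which rounds ties upward, while scalar_round_op_google rounds ties to the nearest even integer. A minimal standalone sketch (helper names hypothetical; std::nearbyint in the default FE_TONEAREST mode emulates ties-to-even) reproduces the two tie values behind the test expectation changes below:

#include <cmath>
#include <cstdio>

// Hypothetical helpers contrasting the two rounding rules.
double RoundHalfUp(double t) { return std::floor(t + 0.5); }  // old behavior
double RoundHalfEven(double t) { return std::nearbyint(t); }  // new behavior

int main() {
  // 0.3 in [0, 1] scaled by 255 lands exactly on 76.5.
  std::printf("%.0f vs %.0f\n", RoundHalfUp(76.5), RoundHalfEven(76.5));    // 77 vs 76
  // -0.5 in [-1, 1] scaled by 127 lands on -63.5.
  std::printf("%.0f vs %.0f\n", RoundHalfUp(-63.5), RoundHalfEven(-63.5));  // -63 vs -64
  return 0;
}

This matches the expectation updates from 77 to 76 and from -63 to -64 in the test hunks that follow.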
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
index 629c698503..cddabf8a99 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
@@ -226,13 +226,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given) {
AddInputFromArray<float>(TensorShape({}), {1.0}); // Max
// Note that the range is given as [-1, 1].
- // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128,
+ // With int8, the tensor is quantized to {-102, -64, 0, 38, 102, 70, -128,
// 127}.
// Scale is: 1/127
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4}));
test::FillValues<float>(
- &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127,
+ &expected, {-102.0 / 127, -64.0 / 127, 0, 38.0 / 127, 102.0 / 127,
70.0 / 127, -128.0 / 127, 1});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
@@ -257,13 +257,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given_V3) {
AddInputFromArray<int32>(TensorShape({}), {8}); // num_bits
// Note that the range is given as [-1, 1].
- // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128,
+ // With int8, the tensor is quantized to {-102, -64, 0, 38, 102, 70, -128,
// 127}.
// Scale is: 1/127
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4}));
test::FillValues<float>(
- &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127,
+ &expected, {-102.0 / 127, -64.0 / 127, 0, 38.0 / 127, 102.0 / 127,
70.0 / 127, -128.0 / 127, 1});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
@@ -285,11 +285,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_4D_tensor_with_uint8_range_given) {
AddInputFromArray<float>(TensorShape({}), {1.0}); // Max
// Note that the range is given as [0, 1].
- // With int8, the tensor is quantized to {0, 0, 77, 204}
+ // With int8, the tensor is quantized to {0, 0, 76, 204}
// Scale is: 1/255
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 1, 1}));
- test::FillValues<float>(&expected, {0, 0, 77.0 / 255, 204.0 / 255});
+ test::FillValues<float>(&expected, {0, 0, 76.0 / 255, 204.0 / 255});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
@@ -311,11 +311,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_4D_tensor_with_uint8_range_given_V3) {
AddInputFromArray<int32>(TensorShape({}), {8}); // num_bits
// Note that the range is given as [0, 1].
- // With int8, the tensor is quantized to {0, 0, 77, 204}
+ // With int8, the tensor is quantized to {0, 0, 76, 204}
// Scale is: 1/255
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 1, 1}));
- test::FillValues<float>(&expected, {0, 0, 77.0 / 255, 204.0 / 255});
+ test::FillValues<float>(&expected, {0, 0, 76.0 / 255, 204.0 / 255});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc
index fdc08ec8e3..64f1b0d661 100644
--- a/tensorflow/core/kernels/spacetobatch_op.cc
+++ b/tensorflow/core/kernels/spacetobatch_op.cc
@@ -42,29 +42,29 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename Device, typename T>
-void SpaceToBatchOpCompute(OpKernelContext* context,
- const Tensor& orig_input_tensor,
- const Tensor& orig_block_shape,
- const Tensor& orig_paddings) {
+Status SpaceToBatchOpCompute(OpKernelContext* context,
+ const Tensor& orig_input_tensor,
+ const Tensor& orig_block_shape,
+ const Tensor& orig_paddings) {
const int input_dims = orig_input_tensor.dims();
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
- errors::InvalidArgument("block_shape rank should be 1 instead of ",
- orig_block_shape.dims()));
+ if (!TensorShapeUtils::IsVector(orig_block_shape.shape())) {
+ return errors::InvalidArgument("block_shape rank should be 1 instead of ",
+ orig_block_shape.dims());
+ }
const int block_dims = orig_block_shape.dim_size(0);
- OP_REQUIRES(
- context, orig_input_tensor.dims() >= 1 + block_dims,
- errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
- " instead of ", orig_input_tensor.dims()));
-
- OP_REQUIRES(context,
- TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
- block_dims == orig_paddings.dim_size(0) &&
- 2 == orig_paddings.dim_size(1),
- errors::InvalidArgument("paddings should have shape [",
- block_dims, ", 2] instead of ",
- orig_paddings.shape().DebugString()));
+ if (orig_input_tensor.dims() < 1 + block_dims) {
+ return errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
+ " instead of ", orig_input_tensor.dims());
+ }
+
+ if (!(TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
+ block_dims == orig_paddings.dim_size(0) &&
+ 2 == orig_paddings.dim_size(1))) {
+ return errors::InvalidArgument("paddings should have shape [", block_dims,
+ ", 2] instead of ",
+ orig_paddings.shape().DebugString());
+ }
// To avoid out-of-bounds access in the case that the block_shape and/or
// paddings tensors are concurrently modified, we must copy the values.
@@ -101,22 +101,23 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
block_shape_product *= block_shape[block_dim];
}
- OP_REQUIRES(
- context, block_shape_product > 0,
- errors::InvalidArgument("Product of block sizes must be positive, got ",
- block_shape_product));
+ if (block_shape_product <= 0) {
+ return errors::InvalidArgument(
+ "Product of block sizes must be positive, got ", block_shape_product);
+ }
const int internal_block_dims =
block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
- OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
- errors::InvalidArgument(
- "Maximum number of non-combined block dimensions is ",
- internal_block_dims, " but must not exceed ",
- kMaxSpaceToBatchBlockDims));
+ if (internal_block_dims > kMaxSpaceToBatchBlockDims) {
+ return errors::InvalidArgument(
+ "Maximum number of non-combined block dimensions is ",
+ internal_block_dims, " but must not exceed ",
+ kMaxSpaceToBatchBlockDims);
+ }
if (internal_block_dims == 0) {
context->set_output(0, orig_input_tensor);
- return;
+ return Status::OK();
}
// For the purpose of computing the result, the input will be treated as
@@ -146,16 +147,18 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
const int64 pad_start = paddings[2 * block_dim],
pad_end = paddings[2 * block_dim + 1];
- OP_REQUIRES(context, pad_start >= 0 && pad_end >= 0,
- errors::InvalidArgument("Paddings must be non-negative"));
+ if (pad_start < 0 || pad_end < 0) {
+ return errors::InvalidArgument("Paddings must be non-negative");
+ }
const int64 input_size = orig_input_tensor.dim_size(block_dim + 1);
const int64 block_shape_value = block_shape[block_dim];
const int64 padded_size = input_size + pad_start + pad_end;
- OP_REQUIRES(
- context, padded_size % block_shape_value == 0,
- errors::InvalidArgument("padded_shape[", block_dim, "]=", padded_size,
- " is not divisible by block_shape[", block_dim,
- "]=", block_shape_value));
+ if (padded_size % block_shape_value != 0) {
+ return errors::InvalidArgument("padded_shape[", block_dim,
+ "]=", padded_size,
+ " is not divisible by block_shape[",
+ block_dim, "]=", block_shape_value);
+ }
internal_input_shape.AddDim(input_size);
const int64 output_size = padded_size / block_shape_value;
internal_output_shape.AddDim(output_size);
@@ -174,29 +177,29 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
// Allocate output tensor.
Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
- &output_tensor));
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(0, external_output_shape, &output_tensor));
const int64* internal_paddings = &paddings[2 * removed_prefix_block_dims];
const int64* internal_block_shape = &block_shape[removed_prefix_block_dims];
switch (internal_block_dims) {
-#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS) \
- case NUM_BLOCK_DIMS: { \
- OP_REQUIRES_OK( \
- context, \
- (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \
- context->eigen_device<Device>(), \
- orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>( \
- internal_input_shape.dim_sizes()), \
- internal_block_shape, internal_paddings, \
- output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>( \
- internal_output_shape.dim_sizes())))); \
- } break; \
+#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS) \
+ case NUM_BLOCK_DIMS: { \
+ TF_RETURN_IF_ERROR( \
+ functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \
+ context->eigen_device<Device>(), \
+ orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>( \
+ internal_input_shape.dim_sizes()), \
+ internal_block_shape, internal_paddings, \
+ output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>( \
+ internal_output_shape.dim_sizes()))); \
+ } break; \
/**/
TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_BLOCK_DIMS_CASE)
#undef TF_SPACETOBATCH_BLOCK_DIMS_CASE
}
+ return Status::OK();
}
} // namespace
@@ -211,8 +214,9 @@ class SpaceToBatchNDOp : public OpKernel {
const Tensor& orig_input_tensor = context->input(0);
const Tensor& orig_block_shape = context->input(1);
const Tensor& orig_paddings = context->input(2);
- SpaceToBatchOpCompute<Device, T>(context, orig_input_tensor,
- orig_block_shape, orig_paddings);
+ OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
+ context, orig_input_tensor, orig_block_shape,
+ orig_paddings));
}
};
@@ -241,7 +245,8 @@ class SpaceToBatchOp : public OpKernel {
OP_REQUIRES(context, kRequiredDims == dims,
errors::InvalidArgument("Input rank should be: ", kRequiredDims,
"instead of: ", dims));
- SpaceToBatchOpCompute<Device, T>(context, in0, block_shape_, in1);
+ OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
+ context, in0, block_shape_, in1));
}
private:
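The SpaceToBatch rewrite above follows a common kernel refactor: validation and allocation move into a helper that returns Status, so downstream Status-returning calls (such as the functor invocation) can use TF_RETURN_IF_ERROR, and Compute() reports the final Status exactly once via OP_REQUIRES_OK. A stripped-down sketch of the pattern, with hypothetical names and a single trivial check (assumes the usual op_kernel.h / errors.h includes):

namespace {

// Hypothetical helper: every check becomes an early `return errors::...`,
// and every fallible call goes through TF_RETURN_IF_ERROR.
Status MyOpCompute(OpKernelContext* context, const Tensor& input) {
  if (input.dims() < 1) {
    return errors::InvalidArgument("input rank should be >= 1 instead of ",
                                   input.dims());
  }
  Tensor* output = nullptr;
  TF_RETURN_IF_ERROR(context->allocate_output(0, input.shape(), &output));
  // ... fill `output` ...
  return Status::OK();
}

}  // namespace

class MyOp : public OpKernel {
 public:
  explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
    // Convert the helper's Status back into the OpKernel error path once.
    OP_REQUIRES_OK(context, MyOpCompute(context, context->input(0)));
  }
};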
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 4a6bedbad8..84b47c171f 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -203,10 +203,12 @@ Status ZlibOutputBuffer::Sync() {
}
Status ZlibOutputBuffer::Close() {
- TF_RETURN_IF_ERROR(DeflateBuffered(true));
- TF_RETURN_IF_ERROR(FlushOutputBufferToFile());
- deflateEnd(z_stream_.get());
- z_stream_.reset(nullptr);
+ if (z_stream_) {
+ TF_RETURN_IF_ERROR(DeflateBuffered(true));
+ TF_RETURN_IF_ERROR(FlushOutputBufferToFile());
+ deflateEnd(z_stream_.get());
+ z_stream_.reset(nullptr);
+ }
return Status::OK();
}
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 267af8b976..b7ca6dbe3d 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -25894,6 +25894,44 @@ op {
}
}
op {
+ name: "If"
+ input_arg {
+ name: "cond"
+ type_attr: "Tcond"
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tcond"
+ type: "type"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "then_branch"
+ type: "func"
+ }
+ attr {
+ name: "else_branch"
+ type: "func"
+ }
+ is_stateful: true
+}
+op {
name: "Igamma"
input_arg {
name: "a"
@@ -27316,6 +27354,30 @@ op {
is_stateful: true
}
op {
+ name: "IteratorGetNextAsOptional"
+ input_arg {
+ name: "iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "IteratorGetNextSync"
input_arg {
name: "iterator"
@@ -35980,6 +36042,64 @@ op {
}
}
op {
+ name: "OptionalFromValue"
+ input_arg {
+ name: "components"
+ type_list_attr: "Toutput_types"
+ }
+ output_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "OptionalGetValue"
+ input_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "OptionalHasValue"
+ input_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "has_value"
+ type: DT_BOOL
+ }
+}
+op {
+ name: "OptionalNone"
+ output_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+}
+op {
name: "OrderedMapClear"
attr {
name: "capacity"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 8c83a09597..28bfa73b6d 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -812,4 +812,33 @@ REGISTER_OP("OptimizeDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("OptionalFromValue")
+ .Input("components: Toutput_types")
+ .Output("optional: variant")
+ .Attr("Toutput_types: list(type) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("OptionalNone")
+ .Output("optional: variant")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("OptionalHasValue")
+ .Input("optional: variant")
+ .Output("has_value: bool")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("OptionalGetValue")
+ .Input("optional: variant")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
+REGISTER_OP("IteratorGetNextAsOptional")
+ .Input("iterator: resource")
+ .Output("optional: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 5f262db2ce..a16ecccf00 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -72,6 +72,7 @@ REGISTER_OP("_If")
.Attr("Tout: list(type)")
.Attr("then_branch: func")
.Attr("else_branch: func")
+ .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
output = cond ? then_branch(input) : else_branch(input)
@@ -98,6 +99,7 @@ REGISTER_OP("If")
.Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
+ .SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
// TODO(drpng): remove this.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 7973be88e0..ef167a2e73 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -12466,6 +12466,7 @@ op {
name: "else_branch"
type: "func"
}
+ is_stateful: true
}
op {
name: "Igamma"
@@ -13290,6 +13291,30 @@ op {
is_stateful: true
}
op {
+ name: "IteratorGetNextAsOptional"
+ input_arg {
+ name: "iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "IteratorGetNextSync"
input_arg {
name: "iterator"
@@ -17299,6 +17324,64 @@ op {
}
}
op {
+ name: "OptionalFromValue"
+ input_arg {
+ name: "components"
+ type_list_attr: "Toutput_types"
+ }
+ output_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "OptionalGetValue"
+ input_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "OptionalHasValue"
+ input_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "has_value"
+ type: DT_BOOL
+ }
+}
+op {
+ name: "OptionalNone"
+ output_arg {
+ name: "optional"
+ type: DT_VARIANT
+ }
+}
+op {
name: "OrderedMapClear"
attr {
name: "capacity"
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 549996aaf8..647a797b82 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -74,6 +74,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":compute_engine_metadata_client",
+ ":compute_engine_zone_provider",
":curl_http_request",
":expiring_lru_cache",
":file_block_cache",
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 2e8d13acd5..67c872ac67 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -57,6 +57,7 @@ constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/";
constexpr char kGcsUploadUriBase[] =
"https://www.googleapis.com/upload/storage/v1/";
constexpr char kStorageHost[] = "storage.googleapis.com";
+constexpr char kBucketMetadataLocationKey[] = "location";
constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024; // In bytes.
constexpr int kGetChildrenDefaultPageSize = 1000;
// The HTTP response code "308 Resume Incomplete".
@@ -98,6 +99,11 @@ constexpr uint64 kMatchingPathsCacheDefaultMaxAge = 0;
constexpr char kMatchingPathsCacheMaxEntries[] =
"GCS_MATCHING_PATHS_CACHE_MAX_ENTRIES";
constexpr size_t kMatchingPathsCacheDefaultMaxEntries = 1024;
+// Number of bucket locations cached; most workloads won't touch more than one
+// bucket, so this limit is set fairly low.
+constexpr size_t kBucketLocationCacheMaxEntries = 10;
+// ExpiringLRUCache doesn't support a "cache forever" option.
+constexpr size_t kCacheNeverExpire = std::numeric_limits<uint64>::max();
// The file statistics returned by Stat() for directories.
const FileStatistics DIRECTORY_STAT(0, 0, true);
// Some environments exhibit unreliable DNS resolution. Set this environment
@@ -131,6 +137,14 @@ constexpr char kTokensPerRequest[] = "GCS_TOKENS_PER_REQUEST";
// The environment variable to configure the initial tokens (format: <int64>)
constexpr char kInitialTokens[] = "GCS_INITIAL_TOKENS";
+// The environment variable to customize which GCS bucket locations are allowed;
+// if the list is empty, it defaults to using the region of the zone (format:
+// comma-delimited list). Requires 'storage.buckets.get' permission.
+constexpr char kAllowedBucketLocations[] = "GCS_ALLOWED_BUCKET_LOCATIONS";
+// When this value is passed as an allowed location, the zone TensorFlow is
+// running in is detected and access is restricted to buckets in that region.
+constexpr char kDetectZoneSentinalValue[] = "auto";
+
// TODO: DO NOT use a hardcoded path
Status GetTmpFilename(string* filename) {
#ifndef _WIN32
@@ -603,6 +617,19 @@ bool StringPieceIdentity(StringPiece str, StringPiece* value) {
return true;
}
+/// \brief Utility function to split a comma-delimited list of strings into an
+/// unordered set.
+bool SplitByCommaToSet(StringPiece list, std::unordered_set<string>* set) {
+ std::vector<string> vector = str_util::Split(list, ",");
+ *set = std::unordered_set<string>(vector.begin(), vector.end());
+ return true;
+}
+
+// \brief Convert Compute Engine zone to region
+string ZoneToRegion(string* zone) {
+ return zone->substr(0, zone->find_last_of('-'));
+}
+
} // namespace
GcsFileSystem::GcsFileSystem() {
@@ -616,6 +643,8 @@ GcsFileSystem::GcsFileSystem() {
std::make_shared<ComputeEngineMetadataClient>(http_request_factory_);
auth_provider_ = std::unique_ptr<AuthProvider>(
new GoogleAuthProvider(compute_engine_metadata_client_));
+ zone_provider_ = std::unique_ptr<ZoneProvider>(
+ new ComputeEngineZoneProvider(compute_engine_metadata_client_));
// Apply the sys env override for the readahead buffer size if it's provided.
if (GetEnvVar(kReadaheadBufferSize, strings::safe_strtou64, &value)) {
@@ -666,6 +695,9 @@ GcsFileSystem::GcsFileSystem() {
matching_paths_cache_.reset(new ExpiringLRUCache<std::vector<string>>(
matching_paths_cache_max_age, matching_paths_cache_max_entries));
+ bucket_location_cache_.reset(new ExpiringLRUCache<string>(
+ kCacheNeverExpire, kBucketLocationCacheMaxEntries));
+
int64 resolve_frequency_secs;
if (GetEnvVar(kResolveCacheSecs, strings::safe_strto64,
&resolve_frequency_secs)) {
@@ -745,24 +777,30 @@ GcsFileSystem::GcsFileSystem() {
}
throttle_.SetConfig(config);
}
+
+ GetEnvVar(kAllowedBucketLocations, SplitByCommaToSet, &allowed_locations_);
}
GcsFileSystem::GcsFileSystem(
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory,
- size_t block_size, size_t max_bytes, uint64 max_staleness,
- uint64 stat_cache_max_age, size_t stat_cache_max_entries,
- uint64 matching_paths_cache_max_age,
+ std::unique_ptr<ZoneProvider> zone_provider, size_t block_size,
+ size_t max_bytes, uint64 max_staleness, uint64 stat_cache_max_age,
+ size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age,
size_t matching_paths_cache_max_entries, int64 initial_retry_delay_usec,
- TimeoutConfig timeouts,
+ TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header)
: auth_provider_(std::move(auth_provider)),
http_request_factory_(std::move(http_request_factory)),
+ zone_provider_(std::move(zone_provider)),
file_block_cache_(
MakeFileBlockCache(block_size, max_bytes, max_staleness)),
stat_cache_(new StatCache(stat_cache_max_age, stat_cache_max_entries)),
matching_paths_cache_(new MatchingPathsCache(
matching_paths_cache_max_age, matching_paths_cache_max_entries)),
+ bucket_location_cache_(new BucketLocationCache(
+ kCacheNeverExpire, kBucketLocationCacheMaxEntries)),
+ allowed_locations_(allowed_locations),
timeouts_(timeouts),
initial_retry_delay_usec_(initial_retry_delay_usec),
additional_header_(additional_header) {}
@@ -771,6 +809,7 @@ Status GcsFileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
+ TF_RETURN_IF_ERROR(CheckBucketLocationConstraint(bucket));
result->reset(new GcsRandomAccessFile(fname, [this, bucket, object](
const string& fname,
uint64 offset, size_t n,
@@ -1072,11 +1111,7 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
}
Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
- std::unique_ptr<HttpRequest> request;
- TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
- request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
- request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
- const Status status = request->Send();
+ const Status status = GetBucketMetadata(bucket, nullptr);
switch (status.code()) {
case errors::Code::OK:
*result = true;
@@ -1089,6 +1124,62 @@ Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
}
}
+Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) {
+ if (allowed_locations_.empty()) {
+ return Status::OK();
+ }
+
+  // Avoid calling external APIs in the constructor.
+ if (allowed_locations_.erase(kDetectZoneSentinalValue) == 1) {
+ string zone;
+ TF_RETURN_IF_ERROR(zone_provider_->GetZone(&zone));
+ allowed_locations_.insert(ZoneToRegion(&zone));
+ }
+
+ string location;
+ TF_RETURN_IF_ERROR(GetBucketLocation(bucket, &location));
+ if (allowed_locations_.find(location) != allowed_locations_.end()) {
+ return Status::OK();
+ }
+
+ return errors::FailedPrecondition(strings::Printf(
+ "Bucket '%s' is in '%s' location, allowed locations are: (%s).",
+ bucket.c_str(), location.c_str(),
+ str_util::Join(allowed_locations_, ", ").c_str()));
+}
+
+Status GcsFileSystem::GetBucketLocation(const string& bucket,
+ string* location) {
+ auto compute_func = [this](const string& bucket, string* location) {
+ std::vector<char> result_buffer;
+ Status status = GetBucketMetadata(bucket, &result_buffer);
+ Json::Value result;
+ TF_RETURN_IF_ERROR(ParseJson(result_buffer, &result));
+ TF_RETURN_IF_ERROR(
+ GetStringValue(result, kBucketMetadataLocationKey, location));
+ return Status::OK();
+ };
+
+ TF_RETURN_IF_ERROR(
+ bucket_location_cache_->LookupOrCompute(bucket, location, compute_func));
+
+ return Status::OK();
+}
+
+Status GcsFileSystem::GetBucketMetadata(const string& bucket,
+ std::vector<char>* result_buffer) {
+ std::unique_ptr<HttpRequest> request;
+ TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
+ request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
+
+ if (result_buffer != nullptr) {
+ request->SetResultBuffer(result_buffer);
+ }
+
+ request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata);
+ return request->Send();
+}
+
Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
StatCache::ComputeFunc compute_func = [this](const string& dirname,
GcsFileStat* stat) {
@@ -1514,6 +1605,7 @@ void GcsFileSystem::FlushCaches() {
file_block_cache_->Flush();
stat_cache_->Clear();
matching_paths_cache_->Clear();
+ bucket_location_cache_->Clear();
}
void GcsFileSystem::SetStats(GcsStatsInterface* stats) {
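The new constraint path above composes three pieces: the comma-delimited allow-list from GCS_ALLOWED_BUCKET_LOCATIONS, the "auto" sentinel that expands to the region of the detected zone, and the cached per-bucket location lookup; an empty allow-list skips the check entirely. A standalone sketch of just that decision logic (the detected zone and bucket location are hard-coded stand-ins for the ZoneProvider and bucket-metadata calls):

#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>

// Mirrors ZoneToRegion above: "us-east1-b" -> "us-east1".
std::string ZoneToRegion(const std::string& zone) {
  return zone.substr(0, zone.find_last_of('-'));
}

std::unordered_set<std::string> SplitByCommaToSet(const std::string& list) {
  std::unordered_set<std::string> out;
  std::stringstream ss(list);
  std::string item;
  while (std::getline(ss, item, ',')) out.insert(item);
  return out;
}

int main() {
  // Hypothetical configuration and environment.
  auto allowed = SplitByCommaToSet("auto,us-west1");  // GCS_ALLOWED_BUCKET_LOCATIONS
  const std::string detected_zone = "us-east1-b";     // stand-in for ZoneProvider::GetZone
  const std::string bucket_location = "us-east1";     // stand-in for GetBucketLocation

  // "auto" expands lazily to the region of the zone we are running in.
  if (allowed.erase("auto") == 1) allowed.insert(ZoneToRegion(detected_zone));

  std::cout << (allowed.count(bucket_location) ? "allowed" : "rejected") << "\n";
  return 0;
}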
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index a0372286f5..71db707687 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cloud/auth_provider.h"
#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
+#include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h"
#include "tensorflow/core/platform/cloud/expiring_lru_cache.h"
#include "tensorflow/core/platform/cloud/file_block_cache.h"
#include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
@@ -81,14 +82,19 @@ class GcsFileSystem : public FileSystem {
public:
struct TimeoutConfig;
+  // Main constructor used (via RetryingFileSystem) throughout TensorFlow.
GcsFileSystem();
+  // Used mostly for unit testing or for use cases that need to customize the
+  // filesystem away from its defaults.
GcsFileSystem(std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory,
- size_t block_size, size_t max_bytes, uint64 max_staleness,
+ std::unique_ptr<ZoneProvider> zone_provider, size_t block_size,
+ size_t max_bytes, uint64 max_staleness,
uint64 stat_cache_max_age, size_t stat_cache_max_entries,
uint64 matching_paths_cache_max_age,
size_t matching_paths_cache_max_entries,
int64 initial_retry_delay_usec, TimeoutConfig timeouts,
+ const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header);
Status NewRandomAccessFile(
@@ -149,6 +155,9 @@ class GcsFileSystem : public FileSystem {
return file_block_cache_->max_staleness();
}
TimeoutConfig timeouts() const { return timeouts_; }
+ std::unordered_set<string> allowed_locations() const {
+ return allowed_locations_;
+ }
string additional_header_name() const {
return additional_header_ ? additional_header_->first : "";
}
@@ -230,6 +239,27 @@ class GcsFileSystem : public FileSystem {
/// 'result' is set if the function returns OK. 'result' cannot be nullptr.
Status BucketExists(const string& bucket, bool* result);
+ /// \brief Retrieves the GCS bucket location. Returns OK if the location was
+ /// retrieved.
+ ///
+  /// Given a bucket name, the GCS bucket metadata API is called and `location`
+  /// is filled with the bucket's location.
+ ///
+ /// This requires the bucket metadata permission.
+  /// Repeated calls for the same bucket are cached, so this function can be
+  /// called frequently without causing extra API calls.
+ Status GetBucketLocation(const string& bucket, string* location);
+
+  /// \brief Checks if the GCS bucket's location is allowed under the current
+  /// constraint configuration.
+ Status CheckBucketLocationConstraint(const string& bucket);
+
+  /// \brief Given the input bucket `bucket`, fills `result_buffer` with the
+  /// bucket's metadata. Returns OK if the API call succeeds without error.
+ Status GetBucketMetadata(const string& bucket,
+ std::vector<char>* result_buffer);
+
/// \brief Checks if the object exists. Returns OK if the check succeeded.
///
/// 'result' is set if the function returns OK. 'result' cannot be nullptr.
@@ -277,6 +307,7 @@ class GcsFileSystem : public FileSystem {
mutex mu_;
std::unique_ptr<AuthProvider> auth_provider_ GUARDED_BY(mu_);
std::shared_ptr<HttpRequest::Factory> http_request_factory_;
+ std::unique_ptr<ZoneProvider> zone_provider_;
// block_cache_lock_ protects the file_block_cache_ pointer (Note that
// FileBlockCache instances are themselves threadsafe).
mutex block_cache_lock_;
@@ -292,6 +323,10 @@ class GcsFileSystem : public FileSystem {
using MatchingPathsCache = ExpiringLRUCache<std::vector<string>>;
std::unique_ptr<MatchingPathsCache> matching_paths_cache_;
+ using BucketLocationCache = ExpiringLRUCache<string>;
+ std::unique_ptr<BucketLocationCache> bucket_location_cache_;
+ std::unordered_set<string> allowed_locations_;
+
TimeoutConfig timeouts_;
GcsStatsInterface* stats_ = nullptr; // Not owned.
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index e791ae5a19..ee2b034d74 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -24,6 +24,13 @@ namespace tensorflow {
namespace {
static GcsFileSystem::TimeoutConfig kTestTimeoutConfig(5, 1, 10, 20, 30);
+// Default (empty) constraint config
+static std::unordered_set<string>* kAllowedLocationsDefault =
+ new std::unordered_set<string>();
+// Constraint config if bucket location constraint is turned on, with no
+// custom list
+static std::unordered_set<string>* kAllowedLocationsAuto =
+ new std::unordered_set<string>({"auto"});
class FakeAuthProvider : public AuthProvider {
public:
@@ -33,6 +40,14 @@ class FakeAuthProvider : public AuthProvider {
}
};
+class FakeZoneProvider : public ZoneProvider {
+ public:
+ Status GetZone(string* zone) override {
+ *zone = "us-east1-b";
+ return Status::OK();
+ }
+};
+
TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
@@ -47,15 +62,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) {
"Range: 6-11\n"
"Timeouts: 5 1 20\n",
"6789")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -74,6 +90,118 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) {
EXPECT_EQ("6789", result);
}
+TEST(GcsFileSystemTest,
+ NewRandomAccessFile_WithLocationConstraintInSameLocation) {
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ R"(
+ {
+ "location":"us-east1"
+ })")});
+
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+
+ std::unique_ptr<RandomAccessFile> file;
+ TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
+}
+
+TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) {
+ std::vector<HttpRequest*> requests(
+ {new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ R"(
+ {
+ "location":"us-east1"
+ })"),
+ new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/anotherbucket\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ R"(
+ {
+ "location":"us-east1"
+ })"),
+ new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ R"(
+ {
+ "location":"us-east1"
+ })")});
+
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+
+ std::unique_ptr<RandomAccessFile> file;
+
+ string bucket = "gs://bucket/random_access.txt";
+ string another_bucket = "gs://anotherbucket/random_access.txt";
+  // Multiple calls should only cause one request to the location API.
+ TF_EXPECT_OK(fs.NewRandomAccessFile(bucket, &file));
+ TF_EXPECT_OK(fs.NewRandomAccessFile(bucket, &file));
+
+ // A new bucket should have one cache miss
+ TF_EXPECT_OK(fs.NewRandomAccessFile(another_bucket, &file));
+ // And then future calls to both should be cached
+ TF_EXPECT_OK(fs.NewRandomAccessFile(bucket, &file));
+ TF_EXPECT_OK(fs.NewRandomAccessFile(another_bucket, &file));
+
+  // Trigger a flush; it should then require one more call.
+ fs.FlushCaches();
+ TF_EXPECT_OK(fs.NewRandomAccessFile(bucket, &file));
+}
+
+TEST(GcsFileSystemTest,
+ NewRandomAccessFile_WithLocationConstraintInDifferentLocation) {
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ R"(
+ {
+ "location":"barfoo"
+ })")});
+
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay */, kTestTimeoutConfig,
+ *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+
+ std::unique_ptr<RandomAccessFile> file;
+ EXPECT_EQ(tensorflow::errors::FailedPrecondition(
+ "Bucket 'bucket' is in 'barfoo' location, allowed locations "
+ "are: (us-east1)."),
+ fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
+}
+
TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
@@ -88,15 +216,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) {
"Range: 3-12\n"
"Timeouts: 5 1 20\n",
"3456789")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -151,11 +280,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
+ 18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
@@ -239,11 +369,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
+ 18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
@@ -287,11 +418,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 8 /* block size */, 16 /* max bytes */, 3600 /* max staleness */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
+ 16 /* max bytes */, 3600 /* max staleness */,
3600 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
// There should only be two HTTP requests issued to GCS even though we iterate
@@ -356,11 +489,12 @@ TEST(GcsFileSystemTest,
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
+ 18 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -383,11 +517,13 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
0 /* read ahead bytes */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -411,15 +547,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) {
"012")});
// Set stat_cache_max_age to 1000s so that StatCache could work.
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 1e3 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 1e3 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Stat the file first so that the file stats are cached.
FileStatistics stat;
@@ -481,11 +618,12 @@ TEST(GcsFileSystemTest, NewWritableFile) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 8 /* block size */, 8 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
+ 8 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Read from the file first, to fill the block cache.
std::unique_ptr<RandomAccessFile> rfile;
@@ -565,15 +703,16 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) {
"Timeouts: 5 1 30\n"
"Put body: t2\n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -638,11 +777,13 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 8 /* block size */, 8 /* max bytes */, 3600 /* max staleness */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
+ 8 /* max bytes */, 3600 /* max staleness */,
3600 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Pull the file's first block into the cache. This will trigger the first
// HTTP request to GCS.
std::unique_ptr<RandomAccessFile> rfile;
@@ -719,15 +860,16 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) {
"Timeouts: 5 1 30\n"
"Put body: content1,content2\n",
""));
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 2 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 2 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -776,15 +918,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
"Timeouts: 5 1 30\n"
"Put body: content1,content2\n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -805,15 +948,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -866,11 +1010,12 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 32 /* block size */, 32 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 32 /* block size */,
+ 32 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Create an appendable file. This should read the file from GCS, and pull its
// contents into the block cache.
@@ -896,15 +1041,16 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -929,15 +1075,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
"Range: 0-",
content.size() - 1, "\n", "Timeouts: 5 1 20\n"),
content)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
@@ -949,15 +1096,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -972,15 +1120,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt"));
}
@@ -1001,15 +1150,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subfolder/\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder"));
}
@@ -1026,15 +1176,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"size\": \"100\"}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket1"));
TF_EXPECT_OK(fs.FileExists("gs://bucket1/"));
@@ -1055,15 +1206,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": []}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::NOT_FOUND,
fs.FileExists("gs://bucket/path/file1.txt").code());
@@ -1081,15 +1233,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.FileExists("gs://bucket2/").code());
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1123,11 +1276,12 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// The stat cache will ensure that repeated lookups don't trigger additional
// HTTP requests.
@@ -1149,11 +1303,12 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/dir/"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/dir/"));
@@ -1167,15 +1322,16 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1194,15 +1350,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1222,15 +1379,16 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) {
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1249,15 +1407,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1273,15 +1432,16 @@ TEST(GcsFileSystemTest, GetChildren_Root) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket-a-b-c", &children));
@@ -1297,15 +1457,16 @@ TEST(GcsFileSystemTest, GetChildren_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1337,15 +1498,16 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) {
" { \"name\": \"path/file4.txt\" },"
" { \"name\": \"path/file5.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1363,15 +1525,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(
@@ -1390,15 +1553,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/*/*", &result));
@@ -1418,15 +1582,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file2.txt", &result));
@@ -1443,15 +1608,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) {
"{\"items\": [ "
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result));
@@ -1468,15 +1634,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file3.txt", &result));
@@ -1485,15 +1652,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1518,15 +1686,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Repeated calls to fs.GetMatchingPaths on these patterns should not lead to
// any additional HTTP requests to GCS.
@@ -1560,15 +1729,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// This loop should trigger the first HTTP request to GCS.
for (int i = 0; i < 10; i++) {
@@ -1627,11 +1797,12 @@ TEST(GcsFileSystemTest, DeleteFile) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 16 /* block size */, 16 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
+ 16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Do an initial read of the file to load its contents into the block cache.
char scratch[100];
@@ -1650,15 +1821,16 @@ TEST(GcsFileSystemTest, DeleteFile) {
TEST(GcsFileSystemTest, DeleteFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.DeleteFile("gs://bucket/").code());
@@ -1696,11 +1868,12 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 16 /* block size */, 16 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
+ 16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Stats the file first so the stat is cached.
FileStatistics stat_before_deletion;
@@ -1721,15 +1894,16 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1749,15 +1923,16 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1768,15 +1943,16 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) {
"name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket"));
}
@@ -1789,15 +1965,16 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/file1.txt\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.DeleteDir("gs://bucket/path/").code());
@@ -1811,15 +1988,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size));
@@ -1828,15 +2006,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
TEST(GcsFileSystemTest, GetFileSize_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1913,15 +2092,16 @@ TEST(GcsFileSystemTest, RenameFile_Folder) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.RenameFile("gs://bucket/path1", "gs://bucket/path2/"));
}
@@ -2008,11 +2188,12 @@ TEST(GcsFileSystemTest, RenameFile_Object) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 16 /* block size */, 64 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
+ 64 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Do an initial read of the source and destination files to load their
// contents into the block cache.
char scratch[100];
@@ -2088,11 +2269,12 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
  // Do an initial stat of the destination file to load its contents into the
// stat cache.
FileStatistics stat_before_renaming;
@@ -2150,15 +2332,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(
fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt"));
@@ -2191,15 +2374,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) {
"Post: yes\n"
"Timeouts: 5 1 10\n",
"{\"done\": false}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(
errors::Code::UNIMPLEMENTED,
@@ -2215,15 +2399,16 @@ TEST(GcsFileSystemTest, Stat_Object) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
@@ -2248,15 +2433,16 @@ TEST(GcsFileSystemTest, Stat_Folder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"subfolder/\" }]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat));
@@ -2280,15 +2466,16 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/path", &stat).code());
@@ -2300,15 +2487,16 @@ TEST(GcsFileSystemTest, Stat_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/", &stat));
@@ -2323,15 +2511,16 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/", &stat).code());
@@ -2364,11 +2553,12 @@ TEST(GcsFileSystemTest, Stat_Cache) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Repeated calls to fs.Stat on these paths should not lead to any additional
// HTTP requests to GCS.
@@ -2405,11 +2595,12 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, nullptr /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// There should be a single HTTP request to GCS for fs.Stat in this loop.
for (int i = 0; i < 10; i++) {
FileStatistics stat;
@@ -2437,15 +2628,16 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"5\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/dir/", &stat));
@@ -2468,15 +2660,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2498,15 +2691,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2528,15 +2722,16 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": [{\"name\": \"subfolder/\"}]}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder/"));
@@ -2554,15 +2749,16 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/"));
@@ -2574,15 +2770,16 @@ TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/").code());
}
@@ -2615,15 +2812,16 @@ TEST(GcsFileSystemTest, CreateDir_Folder) {
"Timeouts: 5 1 30\n"
"Put body: \n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath/"));
@@ -2641,15 +2839,16 @@ TEST(GcsFileSystemTest, CreateDir_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket"));
@@ -2712,15 +2911,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -2804,15 +3004,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) {
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -2838,15 +3039,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay*/, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
EXPECT_EQ(error::Code::NOT_FOUND,
@@ -2857,6 +3059,29 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) {
EXPECT_EQ(1, undeleted_dirs);
}
+TEST(GcsFileSystemTest, NoConstraintsEnvironmentVariableTest) {
+ unsetenv("GCS_ALLOWED_BUCKET_LOCATIONS");
+ // No constraints
+ GcsFileSystem fs1;
+ EXPECT_EQ(*kAllowedLocationsDefault, fs1.allowed_locations());
+
+  // Covers the cache initialization code; any uninitialized cache will cause
+  // this call to fail.
+ fs1.FlushCaches();
+}
+
+TEST(GcsFileSystemTest, BucketLocationConstraintEnvironmentVariableTest) {
+ unsetenv("GCS_ALLOWED_BUCKET_LOCATIONS");
+ setenv("GCS_ALLOWED_BUCKET_LOCATIONS", "auto", 1);
+ GcsFileSystem fs1;
+ EXPECT_EQ(*kAllowedLocationsAuto, fs1.allowed_locations());
+
+ setenv("GCS_ALLOWED_BUCKET_LOCATIONS", "custom,list", 1);
+ GcsFileSystem fs2;
+ EXPECT_EQ(std::unordered_set<string>({"custom", "list"}),
+ fs2.allowed_locations());
+}
+
TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) {
GcsFileSystem fs1;
EXPECT_EQ("", fs1.additional_header_name());
@@ -2902,11 +3127,12 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) {
std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, add_header /* gcs additional header */);
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ add_header /* gcs additional header */);
std::unique_ptr<HttpRequest> request;
TF_EXPECT_OK(fs7.CreateHttpRequest(&request));
@@ -2973,15 +3199,16 @@ TEST(GcsFileSystemTest, CreateHttpRequest) {
"Auth Token: fake_token\n"
"Header Hello: world\n",
"{}")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<HttpRequest> request;
TF_EXPECT_OK(fs.CreateHttpRequest(&request));
@@ -3035,15 +3262,16 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
@@ -3061,15 +3289,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) {
"Range: 0-5\n"
"Timeouts: 5 1 20\n",
"012345")});
- GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
+ 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index e17ecc8c52..5b237c4736 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -232,8 +232,11 @@ class Env {
// TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or
// provide a routine to get the absolute time.
+ /// \brief Returns the number of nano-seconds since the Unix epoch.
+ virtual uint64 NowNanos() { return envTime->NowNanos(); }
+
/// \brief Returns the number of micro-seconds since the Unix epoch.
- virtual uint64 NowMicros() { return envTime->NowMicros(); };
+ virtual uint64 NowMicros() { return envTime->NowMicros(); }
/// \brief Returns the number of seconds since the Unix epoch.
virtual uint64 NowSeconds() { return envTime->NowSeconds(); }
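Editor's note, not part of the patch: a minimal sketch of how the new NowNanos() accessor can be used, assuming only the standard tensorflow::Env::Default() singleton declared in this header.

    #include "tensorflow/core/platform/env.h"

    // Illustrative only: time a code region at nanosecond granularity.
    void TimedRegion() {
      tensorflow::Env* env = tensorflow::Env::Default();
      const tensorflow::uint64 start_ns = env->NowNanos();
      // ... work to be measured ...
      const tensorflow::uint64 elapsed_ns = env->NowNanos() - start_ns;
      (void)elapsed_ns;  // e.g. feed into a log line or a metrics counter
    }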
diff --git a/tensorflow/core/platform/s3/s3_crypto.cc b/tensorflow/core/platform/s3/s3_crypto.cc
deleted file mode 100644
index d7062a59d2..0000000000
--- a/tensorflow/core/platform/s3/s3_crypto.cc
+++ /dev/null
@@ -1,113 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/platform/s3/s3_crypto.h"
-#include <openssl/hmac.h>
-#include <openssl/sha.h>
-
-#include <aws/core/utils/crypto/HashResult.h>
-#include <aws/s3/S3Client.h>
-
-namespace tensorflow {
-
-class S3Sha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
- public:
- S3Sha256HMACOpenSSLImpl() {}
-
- virtual ~S3Sha256HMACOpenSSLImpl() = default;
-
- virtual Aws::Utils::Crypto::HashResult Calculate(
- const Aws::Utils::ByteBuffer& toSign,
- const Aws::Utils::ByteBuffer& secret) override {
- unsigned int length = SHA256_DIGEST_LENGTH;
- Aws::Utils::ByteBuffer digest(length);
- memset(digest.GetUnderlyingData(), 0, length);
-
- HMAC_CTX ctx;
- HMAC_CTX_init(&ctx);
-
- HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
- static_cast<int>(secret.GetLength()), EVP_sha256(), NULL);
- HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength());
- HMAC_Final(&ctx, digest.GetUnderlyingData(), &length);
- HMAC_CTX_cleanup(&ctx);
-
- return Aws::Utils::Crypto::HashResult(std::move(digest));
- }
-};
-
-class S3Sha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
- public:
- S3Sha256OpenSSLImpl() {}
-
- virtual ~S3Sha256OpenSSLImpl() = default;
-
- virtual Aws::Utils::Crypto::HashResult Calculate(
- const Aws::String& str) override {
- SHA256_CTX sha256;
- SHA256_Init(&sha256);
- SHA256_Update(&sha256, str.data(), str.size());
-
- Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
- SHA256_Final(hash.GetUnderlyingData(), &sha256);
-
- return Aws::Utils::Crypto::HashResult(std::move(hash));
- }
-
- virtual Aws::Utils::Crypto::HashResult Calculate(
- Aws::IStream& stream) override {
- SHA256_CTX sha256;
- SHA256_Init(&sha256);
-
- auto currentPos = stream.tellg();
- if (currentPos == std::streampos(std::streamoff(-1))) {
- currentPos = 0;
- stream.clear();
- }
-
- stream.seekg(0, stream.beg);
-
- char streamBuffer
- [Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
- while (stream.good()) {
- stream.read(streamBuffer,
- Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE);
- auto bytesRead = stream.gcount();
-
- if (bytesRead > 0) {
- SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead));
- }
- }
-
- stream.clear();
- stream.seekg(currentPos, stream.beg);
-
- Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
- SHA256_Final(hash.GetUnderlyingData(), &sha256);
-
- return Aws::Utils::Crypto::HashResult(std::move(hash));
- }
-};
-
-std::shared_ptr<Aws::Utils::Crypto::Hash>
-S3SHA256Factory::CreateImplementation() const {
- return Aws::MakeShared<S3Sha256OpenSSLImpl>(S3CryptoAllocationTag);
-}
-
-std::shared_ptr<Aws::Utils::Crypto::HMAC>
-S3SHA256HmacFactory::CreateImplementation() const {
- return Aws::MakeShared<S3Sha256HMACOpenSSLImpl>(S3CryptoAllocationTag);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/s3_crypto.h b/tensorflow/core/platform/s3/s3_crypto.h
deleted file mode 100644
index e376b8b0c0..0000000000
--- a/tensorflow/core/platform/s3/s3_crypto.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <aws/core/Aws.h>
-#include <aws/core/utils/crypto/Factories.h>
-#include <aws/core/utils/crypto/HMAC.h>
-#include <aws/core/utils/crypto/Hash.h>
-
-namespace tensorflow {
-static const char* S3CryptoAllocationTag = "S3CryptoAllocation";
-
-class S3SHA256Factory : public Aws::Utils::Crypto::HashFactory {
- public:
- std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
- const override;
-};
-
-class S3SHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
- public:
- std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
- const override;
-};
-
-} // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index bdc8f808df..d5f5dec390 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -187,9 +187,7 @@ class S3RandomAccessFile : public RandomAccessFile {
return Status(error::OUT_OF_RANGE, "Read less bytes than requested");
}
n = getObjectOutcome.GetResult().GetContentLength();
- std::stringstream ss;
- ss << getObjectOutcome.GetResult().GetBody().rdbuf();
- ss.read(scratch, n);
+ getObjectOutcome.GetResult().GetBody().read(scratch, n);
*result = StringPiece(scratch, n);
return Status::OK();
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index a3bc2f422e..74058c8465 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -466,6 +466,11 @@ message RecvBufRequest {
// Optional, for annotating the timeline.
string src_device = 8;
string dst_device = 9;
+
+ // Depending on the RPC system in use, it may be necessary to set this
+ // id to detect resends of RPCs where the server is not aware that
+ // the prior RPC failed.
+ int64 request_id = 10;
}
message RecvBufResponse {
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index cea5e8ffb0..6f564e7e1e 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 9
+#define TF_MINOR_VERSION 10
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc1"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h
index 53087821d7..973e315f09 100644
--- a/tensorflow/core/util/ctc/ctc_beam_entry.h
+++ b/tensorflow/core/util/ctc/ctc_beam_entry.h
@@ -1,3 +1,4 @@
+// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -145,3 +146,4 @@ class BeamComparer {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
+// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h)
diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h
index 2579198ece..1a622babe1 100644
--- a/tensorflow/core/util/ctc/ctc_beam_scorer.h
+++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h
@@ -1,3 +1,4 @@
+// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -73,3 +74,4 @@ class BaseBeamScorer {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_
+// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h)
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index 709c65fc96..aee647a1b3 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -418,3 +418,4 @@ Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
+// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h)
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index b8bab69053..3be36822e5 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -1,3 +1,4 @@
+// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -112,3 +113,4 @@ class CTCGreedyDecoder : public CTCDecoder {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
+// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h)
diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h
index 50f8f49f1c..36be9e92ef 100644
--- a/tensorflow/core/util/ctc/ctc_loss_util.h
+++ b/tensorflow/core/util/ctc/ctc_loss_util.h
@@ -1,3 +1,4 @@
+// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -46,3 +47,4 @@ inline float LogSumExp(float log_prob_1, float log_prob_2) {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
+// LINT.ThenChange(//tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h)
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index 3ce7988057..418e97ac24 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -325,9 +325,9 @@ bool ParseExample(protobuf::io::CodedInputStream* stream,
while (!stream->ExpectAtEnd()) {
if (!stream->ExpectTag(kDelimitedTag(1))) {
if (!SkipExtraneousTag(stream)) return false;
- continue;
+ } else {
+ if (!ParseFeatures(stream, example)) return false;
}
- if (!ParseFeatures(stream, example)) return false;
}
return true;
}
@@ -1455,5 +1455,773 @@ Status FastParseSingleExample(const Config& config, const string& serialized,
return Status::OK();
}
+// Return the number of bytes elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
+ string* out) {
+ int num_elements = 0;
+ uint32 length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
+ return -1;
+ }
+ if (length > 0) {
+ auto limit = stream->PushLimit(length);
+ while (!stream->ExpectAtEnd()) {
+ uint32 bytes_length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) ||
+ !stream->ReadVarint32(&bytes_length) ||
+ (out != nullptr && !stream->ReadString(out++, bytes_length))) {
+ return -1;
+ }
+ if (out == nullptr) {
+ stream->Skip(bytes_length);
+ }
+ num_elements++;
+ }
+ stream->PopLimit(limit);
+ }
+ return num_elements;
+}
+
+inline void PadFloatFeature(int num_to_pad, float* out) {
+ for (int i = 0; i < num_to_pad; i++) {
+ *out++ = 0.0;
+ }
+}
+
+inline void PadInt64Feature(int num_to_pad, int64* out) {
+ for (int i = 0; i < num_to_pad; i++) {
+ *out++ = 0;
+ }
+}
+
+// Return the number of float elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
+ float* out) {
+ int num_elements = 0;
+ uint32 length;
+ if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
+ return -1;
+ }
+ if (length > 0) {
+ auto limit = stream->PushLimit(length);
+ uint8 peek_tag = PeekTag(stream);
+ if (peek_tag == kDelimitedTag(1)) { // packed
+ uint32 packed_length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) ||
+ !stream->ReadVarint32(&packed_length)) {
+ return -1;
+ }
+ auto packed_limit = stream->PushLimit(packed_length);
+ while (!stream->ExpectAtEnd()) {
+ uint32 buffer32;
+ if (!stream->ReadLittleEndian32(&buffer32)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = bit_cast<float>(buffer32);
+ }
+ num_elements++;
+ }
+ stream->PopLimit(packed_limit);
+ } else if (peek_tag == kFixed32Tag(1)) {
+ while (!stream->ExpectAtEnd()) {
+ uint32 buffer32;
+ if (!stream->ExpectTag(kFixed32Tag(1)) ||
+ !stream->ReadLittleEndian32(&buffer32)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = bit_cast<float>(buffer32);
+ }
+ num_elements++;
+ }
+ } else {
+ // Unknown tag.
+ return -1;
+ }
+ stream->PopLimit(limit);
+ }
+ return num_elements;
+}
+
+// Return the number of int64 elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
+ int64* out) {
+ int num_elements = 0;
+ uint32 length;
+ if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
+ return -1;
+ }
+ if (length > 0) {
+ auto limit = stream->PushLimit(length);
+ uint8 peek_tag = PeekTag(stream);
+ if (peek_tag == kDelimitedTag(1)) { // packed
+ uint32 packed_length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) ||
+ !stream->ReadVarint32(&packed_length)) {
+ return -1;
+ }
+ auto packed_limit = stream->PushLimit(packed_length);
+ while (!stream->ExpectAtEnd()) {
+ protobuf_uint64 n; // There is no API for int64
+ if (!stream->ReadVarint64(&n)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = n;
+ }
+ num_elements++;
+ }
+ stream->PopLimit(packed_limit);
+ } else if (peek_tag == kVarintTag(1)) {
+ while (!stream->ExpectAtEnd()) {
+ protobuf_uint64 n; // There is no API for int64
+ if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = n;
+ }
+ num_elements++;
+ }
+ } else {
+ // Unknown tag.
+ return -1;
+ }
+ stream->PopLimit(limit);
+ }
+ return num_elements;
+}
+
+inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
+ uint8 peek_tag = PeekTag(stream);
+ switch (peek_tag) {
+ case kDelimitedTag(1):
+ return DT_STRING;
+ case kDelimitedTag(2):
+ return DT_FLOAT;
+ case kDelimitedTag(3):
+ return DT_INT64;
+ default:
+ return DT_INVALID;
+ }
+}
+
+inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
+ DataType dtype) {
+ switch (dtype) {
+ case DT_STRING:
+ if (!stream->ExpectTag(kDelimitedTag(1))) {
+ return false;
+ }
+ break;
+ case DT_FLOAT:
+ if (!stream->ExpectTag(kDelimitedTag(2))) {
+ return false;
+ }
+ break;
+ case DT_INT64:
+ if (!stream->ExpectTag(kDelimitedTag(3))) {
+ return false;
+ }
+ break;
+ default:
+ return false;
+ }
+ uint32 length;
+ return stream->ReadVarint32(&length) && length == 0;
+}
+
+// TODO(sundberg): Use the threadpool to parallelize example parsing.
+Status FastParseSequenceExample(
+ const FastParseExampleConfig& context_config,
+ const FastParseExampleConfig& feature_list_config,
+ gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
+ thread::ThreadPool* thread_pool, Result* context_result,
+ Result* feature_list_result) {
+ int num_examples = serialized.size();
+ DCHECK(context_result != nullptr);
+ DCHECK(feature_list_result != nullptr);
+ std::map<StringPiece, bool> context_is_sparse;
+ std::map<StringPiece, std::pair<DataType, size_t>>
+ context_feature_type_and_lengths;
+ if (!example_names.empty() && example_names.size() != num_examples) {
+ return errors::InvalidArgument(
+ "example_names must be empty or have the correct number of elements");
+ }
+ for (auto& c : context_config.sparse) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ context_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ context_is_sparse[c.feature_name] = true;
+ }
+ for (auto& c : context_config.dense) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ context_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ context_is_sparse[c.feature_name] = false;
+ }
+ std::map<StringPiece, bool> sequence_is_sparse;
+ std::map<StringPiece, std::pair<DataType, size_t>>
+ sequence_feature_type_and_lengths;
+ for (auto& c : feature_list_config.sparse) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ sequence_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ sequence_is_sparse[c.feature_name] = true;
+ }
+ for (auto& c : feature_list_config.dense) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ sequence_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ sequence_is_sparse[c.feature_name] = false;
+ }
+
+ std::vector<std::map<StringPiece, StringPiece>> all_context_features(
+ num_examples);
+ std::vector<std::map<StringPiece, StringPiece>> all_sequence_features(
+ num_examples);
+ const string kUnknown = "<unknown>";
+ for (int d = 0; d < num_examples; d++) {
+ const string& example = serialized[d];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[d];
+ auto* context_features = &all_context_features[d];
+ auto* sequence_features = &all_sequence_features[d];
+
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(example.data()), example.size());
+ // Not clear what this does. Why not stream.EnableAliasing()?
+ EnableAliasing(&stream);
+
+ // Extract pointers to all features within this serialized example.
+ while (!stream.ExpectAtEnd()) {
+ std::map<StringPiece, StringPiece>* features = nullptr;
+ const std::map<StringPiece, std::pair<DataType, size_t>>* config =
+ nullptr;
+ if (stream.ExpectTag(kDelimitedTag(1))) {
+ // Context
+ features = context_features;
+ config = &context_feature_type_and_lengths;
+ } else if (stream.ExpectTag(kDelimitedTag(2))) {
+ // Sequence
+ features = sequence_features;
+ config = &sequence_feature_type_and_lengths;
+ } else if (!SkipExtraneousTag(&stream)) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ if (features != nullptr) {
+ uint32 length;
+ if (!stream.ReadVarint32(&length)) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ auto limit = stream.PushLimit(length);
+ while (!stream.ExpectAtEnd()) {
+ StringPiece key, value;
+ uint32 length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&length)) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ auto limit = stream.PushLimit(length);
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !ParseString(&stream, &key) ||
+ !stream.ExpectTag(kDelimitedTag(2)) ||
+ !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ stream.PopLimit(limit);
+ // Only save if this feature was requested.
+ if (config->count(key) > 0) {
+ (*features)[key] = value;
+ }
+ }
+ stream.PopLimit(limit);
+ }
+ }
+
+ for (const auto& c : *context_features) {
+ size_t num_elements = 0;
+ if (!c.second.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
+ EnableAliasing(&stream);
+ DataType dtype = context_feature_type_and_lengths[c.first].first;
+ int64 num;
+ switch (dtype) {
+ case DT_STRING:
+ num = ParseBytesFeature(&stream, nullptr);
+ break;
+ case DT_FLOAT:
+ num = ParseFloatFeature(&stream, nullptr);
+ break;
+ case DT_INT64:
+ num = ParseInt64Feature(&stream, nullptr);
+ break;
+ default:
+ num = -1;
+ break;
+ }
+ if (num == -1) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in context feature ", c.first,
+ " in example ", example_name));
+ }
+ num_elements += num;
+ }
+ if (context_is_sparse[c.first]) {
+ context_feature_type_and_lengths[c.first].second += num_elements;
+ } else {
+ size_t current_max = context_feature_type_and_lengths[c.first].second;
+ context_feature_type_and_lengths[c.first].second =
+ std::max(current_max, num_elements);
+ }
+ }
+ for (const auto& c : *sequence_features) {
+ size_t num_elements = 0;
+ if (!c.second.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
+ EnableAliasing(&stream);
+ DataType dtype = sequence_feature_type_and_lengths[c.first].first;
+ while (!stream.ExpectAtEnd()) {
+ uint32 feature_length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&feature_length)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ if (feature_length > 2) {
+ auto limit = stream.PushLimit(feature_length);
+ int64 num;
+ switch (dtype) {
+ case DT_STRING:
+ num = ParseBytesFeature(&stream, nullptr);
+ break;
+ case DT_FLOAT:
+ num = ParseFloatFeature(&stream, nullptr);
+ break;
+ case DT_INT64:
+ num = ParseInt64Feature(&stream, nullptr);
+ break;
+ default:
+ num = -1;
+ break;
+ }
+ if (num == -1) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ num_elements += num;
+ stream.PopLimit(limit);
+ } else if (feature_length == 2) {
+ if (!SkipEmptyFeature(&stream, dtype)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ } else if (feature_length != 0) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ }
+ }
+ if (sequence_is_sparse[c.first]) {
+ sequence_feature_type_and_lengths[c.first].second += num_elements;
+ } else {
+ size_t current_max = sequence_feature_type_and_lengths[c.first].second;
+ sequence_feature_type_and_lengths[c.first].second =
+ std::max(current_max, num_elements);
+ }
+ }
+ }
+
+ // Allocate memory.
+ context_result->sparse_values.resize(context_config.sparse.size());
+ context_result->sparse_indices.resize(context_config.sparse.size());
+ context_result->sparse_shapes.resize(context_config.sparse.size());
+ context_result->dense_values.resize(context_config.dense.size());
+ feature_list_result->sparse_values.resize(feature_list_config.sparse.size());
+ feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
+ feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
+ feature_list_result->dense_values.resize(feature_list_config.dense.size());
+ int t = 0;
+ for (const auto& c : context_config.dense) {
+ TensorShape dense_shape;
+ DataType dtype = c.dtype;
+ size_t expected_max_elements =
+ context_feature_type_and_lengths[c.feature_name].second;
+ if (expected_max_elements != dense_shape.num_elements()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Inconsistent number of elements for feature ", c.feature_name));
+ }
+ dense_shape.AddDim(num_examples);
+ for (const int dim : c.shape.dim_sizes()) {
+ dense_shape.AddDim(dim);
+ }
+ context_result->dense_values[t] = Tensor(dtype, dense_shape);
+
+ // TODO(sundberg): Refactor to reduce code duplication, and add bounds
+ // checking for the outputs.
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = context_result->dense_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = context_result->dense_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = context_result->dense_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ t++;
+
+ // Fill in the values.
+ for (int e = 0; e < num_examples; e++) {
+ size_t num_elements = 0;
+ const auto& feature = all_context_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ }
+ if (num_elements != expected_max_elements) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected number of elements in example ", example_name));
+ }
+ }
+ }
+ t = 0;
+ for (const auto& c : context_config.sparse) {
+ TensorShape indices_shape, values_shape;
+ DataType dtype = c.dtype;
+ size_t expected_num_elements =
+ context_feature_type_and_lengths[c.feature_name].second;
+ indices_shape.AddDim(expected_num_elements);
+ indices_shape.AddDim(2);
+ values_shape.AddDim(expected_num_elements);
+ context_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
+ context_result->sparse_values[t] = Tensor(dtype, values_shape);
+ context_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({2}));
+ // TODO(sundberg): Refactor to reduce code duplication, and add bounds
+ // checking for the outputs.
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = context_result->sparse_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = context_result->sparse_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = context_result->sparse_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
+ auto out_shape = context_result->sparse_shapes[t].vec<int64>();
+ t++;
+
+ // Fill in the values.
+ size_t num_elements = 0;
+ size_t max_num_cols = 0;
+ for (int e = 0; e < num_examples; e++) {
+ const auto& feature = all_context_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ max_num_cols = std::max(max_num_cols, num_added);
+ for (int i = 0; i < num_added; i++) {
+ *out_indices++ = e;
+ *out_indices++ = i;
+ }
+ }
+ }
+ if (num_elements != expected_num_elements) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected total number of elements in feature ", c.feature_name));
+ }
+ out_shape(0) = num_examples;
+ out_shape(1) = max_num_cols;
+ }
+ t = 0;
+ for (const auto& c : feature_list_config.dense) {
+ TensorShape dense_shape, row_shape;
+ DataType dtype = c.dtype;
+ size_t expected_max_elements =
+ sequence_feature_type_and_lengths[c.feature_name].second;
+ int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
+ if (!c.shape.AsTensorShape(&row_shape) ||
+ expected_max_elements != expected_max_rows * row_shape.num_elements()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected shape error in feature ", c.feature_name));
+ }
+ dense_shape.AddDim(num_examples);
+ dense_shape.AddDim(expected_max_rows);
+ for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
+ dense_shape.AddDim(dim);
+ }
+ feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
+
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = feature_list_result->dense_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = feature_list_result->dense_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ t++;
+
+ // Fill in the values.
+ for (int e = 0; e < num_examples; e++) {
+ size_t num_elements = 0;
+ const auto& feature = all_sequence_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ while (!stream.ExpectAtEnd()) {
+ uint32 feature_length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&feature_length)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ auto limit = stream.PushLimit(feature_length);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ if (num_added != row_shape.num_elements()) {
+ return errors::InvalidArgument(
+ "Unexpected number of elements in feature ", c.feature_name,
+ ", example ", example_name);
+ }
+ stream.PopLimit(limit);
+ }
+ }
+ // Pad as necessary.
+ int num_to_pad = expected_max_elements - num_elements;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes += num_to_pad;
+ break;
+ case DT_FLOAT:
+ PadFloatFeature(num_to_pad, out_float);
+ out_float += num_to_pad;
+ break;
+ case DT_INT64:
+ PadInt64Feature(num_to_pad, out_int64);
+ out_int64 += num_to_pad;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ }
+ }
+ t = 0;
+ for (const auto& c : feature_list_config.sparse) {
+ TensorShape indices_shape, values_shape;
+ DataType dtype = c.dtype;
+ size_t expected_num_elements =
+ sequence_feature_type_and_lengths[c.feature_name].second;
+ indices_shape.AddDim(expected_num_elements);
+ indices_shape.AddDim(3);
+ values_shape.AddDim(expected_num_elements);
+ feature_list_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
+ feature_list_result->sparse_values[t] = Tensor(dtype, values_shape);
+ feature_list_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({3}));
+
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = feature_list_result->sparse_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = feature_list_result->sparse_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ int64* out_indices =
+ feature_list_result->sparse_indices[t].flat<int64>().data();
+ auto out_shape = feature_list_result->sparse_shapes[t].vec<int64>();
+ t++;
+
+ // Fill in the values.
+ size_t num_elements = 0;
+ size_t max_num_rows = 0;
+ size_t max_num_cols = 0;
+ for (int e = 0; e < num_examples; e++) {
+ const auto& feature = all_sequence_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ size_t num_rows = 0;
+ while (!stream.ExpectAtEnd()) {
+ uint32 feature_length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&feature_length)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ if (feature_length > 2) {
+ auto limit = stream.PushLimit(feature_length);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ max_num_cols = std::max(max_num_cols, num_added);
+ for (int i = 0; i < num_added; i++) {
+ *out_indices++ = e;
+ *out_indices++ = num_rows;
+ *out_indices++ = i;
+ }
+ stream.PopLimit(limit);
+ } else if (feature_length == 2) {
+ if (!SkipEmptyFeature(&stream, dtype)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ } else if (feature_length != 0) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ num_rows++;
+ }
+ max_num_rows = std::max(max_num_rows, num_rows);
+ }
+ }
+ if (num_elements != expected_num_elements) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected number of elements in feature ", c.feature_name));
+ }
+ out_shape(0) = num_examples;
+ out_shape(1) = max_num_rows;
+ out_shape(2) = max_num_cols;
+ }
+
+ return Status::OK();
+}
+
} // namespace example
} // namespace tensorflow
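Editor's note, not part of the patch: the parse helpers added above (ParseBytesFeature, ParseFloatFeature, ParseInt64Feature) are used in a two-pass pattern — a first call with out == nullptr only counts elements so the output tensor can be sized, and a second call copies into the allocated buffer. A condensed sketch of that pattern for the int64 variant follows; it would live in the same translation unit as these internal helpers, and the function name and the assumption that `feature` holds one serialized Feature are illustrative, not taken from the patch.

    #include <vector>

    // Illustrative two-pass use of ParseInt64Feature: count, size, then fill.
    std::vector<int64> ParseInt64TwoPass(StringPiece feature) {
      protobuf::io::CodedInputStream count_stream(
          reinterpret_cast<const uint8*>(feature.data()), feature.size());
      const int num = ParseInt64Feature(&count_stream, nullptr);  // count only
      if (num < 0) return {};  // malformed feature
      std::vector<int64> values(num);
      protobuf::io::CodedInputStream fill_stream(
          reinterpret_cast<const uint8*>(feature.data()), feature.size());
      ParseInt64Feature(&fill_stream, values.data());  // copy pass
      return values;
    }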
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index 1b08f02267..024a4518ee 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -85,6 +85,17 @@ typedef FastParseExampleConfig FastParseSingleExampleConfig;
Status FastParseSingleExample(const FastParseSingleExampleConfig& config,
const string& serialized, Result* result);
+// Parses a batch of serialized SequenceExample protos and writes the parsed
+// features into the given result structs according to the given configs.
+// example_names must either be empty or have the same size as serialized;
+// it is used only for error messages.
+Status FastParseSequenceExample(
+ const example::FastParseExampleConfig& context_config,
+ const example::FastParseExampleConfig& feature_list_config,
+ gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
+ thread::ThreadPool* thread_pool, example::Result* context_result,
+ example::Result* feature_list_result);
+
// This function parses serialized Example and populates given example.
// It uses the same specialized parser as FastParseExample which is efficient.
// But then constructs Example which is relatively slow.
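Editor's note, not part of the patch: a hypothetical caller sketch based only on the declaration above. The helper name and the empty configs are illustrative, and passing a null thread_pool mirrors the current implementation, which does not yet parallelize (see the TODO in the .cc).

    #include <vector>
    #include "tensorflow/core/util/example_proto_fast_parsing.h"

    // Sketch only: parse a batch of serialized SequenceExample records.
    tensorflow::Status ParseSequenceBatch(
        const std::vector<tensorflow::string>& serialized_records) {
      tensorflow::example::FastParseExampleConfig context_config;       // fill .dense / .sparse
      tensorflow::example::FastParseExampleConfig feature_list_config;  // fill .dense / .sparse
      tensorflow::example::Result context_result;
      tensorflow::example::Result feature_list_result;
      return tensorflow::example::FastParseSequenceExample(
          context_config, feature_list_config, serialized_records,
          /*example_names=*/{}, /*thread_pool=*/nullptr,
          &context_result, &feature_list_result);
    }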
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 566a42dbd5..d90f85e422 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -1882,9 +1882,8 @@ class MklPrimitive {
public:
virtual ~MklPrimitive() {}
- // Dummy data. Its size, hard-coded as 256 here, does
- // not matter since MKL should never operate on this buffer.
- unsigned char DummyData[256];
+ // Dummy data which MKL DNN never operates on
+ unsigned char* DummyData = nullptr;
};
const mkldnn::memory::dims NONE_DIMS = {};
@@ -1896,8 +1895,9 @@ class MklPrimitiveFactory {
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
- auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
- if (stream_iter == MklPrimitiveFactory<T>::GetHashMap().end()) {
+ auto& map = MklPrimitiveFactory<T>::GetHashMap();
+ auto stream_iter = map.find(key);
+ if (stream_iter == map.end()) {
return nullptr;
} else {
CHECK(stream_iter->second != nullptr) << "nullptr present in map";
@@ -1906,11 +1906,12 @@ class MklPrimitiveFactory {
}
void SetOp(const string& key, MklPrimitive* op) {
- auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
+ auto& map = MklPrimitiveFactory<T>::GetHashMap();
+ auto stream_iter = map.find(key);
- CHECK(stream_iter == MklPrimitiveFactory<T>::GetHashMap().end());
+ CHECK(stream_iter == map.end());
- MklPrimitiveFactory<T>::GetHashMap()[key] = op;
+ map[key] = op;
}
private:
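Editor's note, not part of the patch: the mkl_util.h change above simply hoists the repeated MklPrimitiveFactory<T>::GetHashMap() calls into a single local reference. A generic sketch of that lookup-caching pattern, with names that are illustrative and unrelated to the patch:

    #include <cassert>
    #include <string>
    #include <unordered_map>

    std::unordered_map<std::string, int>& GetRegistry() {
      static std::unordered_map<std::string, int> registry;  // stands in for GetHashMap()
      return registry;
    }

    void RegisterOnce(const std::string& key, int value) {
      auto& registry = GetRegistry();                // fetch the container once
      assert(registry.find(key) == registry.end());  // mirrors the CHECK above
      registry[key] = value;                         // reuse the cached reference
    }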
diff --git a/tensorflow/docs_src/guide/custom_estimators.md b/tensorflow/docs_src/guide/custom_estimators.md
index a63e2bafb3..6e4ef2e0f2 100644
--- a/tensorflow/docs_src/guide/custom_estimators.md
+++ b/tensorflow/docs_src/guide/custom_estimators.md
@@ -149,7 +149,7 @@ model. This configuration step is similar to how we configured the @{tf.estimato
```python
classifier = tf.estimator.Estimator(
- model_fn=my_model,
+ model_fn=my_model_fn,
params={
'feature_columns': my_feature_columns,
# Two hidden layers of 10 nodes each.
@@ -474,7 +474,7 @@ Instantiate the custom Estimator through the Estimator base class as follows:
```python
# Build 2 hidden layer DNN with 10, 10 units respectively.
classifier = tf.estimator.Estimator(
- model_fn=my_model,
+ model_fn=my_model_fn,
params={
'feature_columns': my_feature_columns,
# Two hidden layers of 10 nodes each.
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index cf869e8655..5e26facaba 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -38,7 +38,7 @@ enable TensorFlow for C:
OS="linux" # Change to "darwin" for macOS
TARGET_DIRECTORY="/usr/local"
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.9.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.10.0-rc1.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index 4ec7e42773..a59c2741e1 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go:
TF_TYPE="cpu" # Change to "gpu" for GPU support
TARGET_DIRECTORY='/usr/local'
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.9.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.10.0-rc1.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index c5f760d254..e9c6650c92 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
</dependency>
```
@@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
</dependency>
</dependencies>
</project>
@@ -124,12 +124,12 @@ instead:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
</dependency>
```
@@ -148,7 +148,7 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or macOS:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.10.0-rc1.jar),
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -167,7 +167,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.9.0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.10.0-rc1.tar.gz" |
tar -xz -C ./jni
### Install on Windows
@@ -175,10 +175,10 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.10.0-rc1.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.9.0.zip).
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.10.0-rc1.zip).
3. Extract this .zip file.
__Note__: The native library (`tensorflow_jni.dll`) requires `msvcp140.dll` at runtime, which is included in the [Visual C++ 2015 Redistributable](https://www.microsoft.com/en-us/download/details.aspx?id=48145) package.
@@ -227,7 +227,7 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
-<pre><b>javac -cp libtensorflow-1.9.0.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.10.0-rc1.jar HelloTF.java</b></pre>
### Running
@@ -241,11 +241,11 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program on Linux
and macOS X:
-<pre><b>java -cp libtensorflow-1.9.0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.10.0-rc1.jar:. -Djava.library.path=./jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
-<pre><b>java -cp libtensorflow-1.9.0.jar;. -Djava.library.path=jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.10.0-rc1.jar;. -Djava.library.path=jni HelloTF</b></pre>
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 3a9a01c57e..005ad437bc 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -436,7 +436,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
<pre>
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0rc1-cp34-cp34m-linux_x86_64.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -650,13 +650,13 @@ This section documents the relevant values for Linux installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0rc1-cp27-none-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0rc1-cp27-none-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -667,13 +667,13 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0rc1-cp34-cp34m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0rc1-cp34-cp34m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -684,13 +684,13 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0rc1-cp35-cp35m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0rc1-cp35-cp35m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -701,13 +701,13 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0rc1-cp36-cp36m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0rc1-cp36-cp36m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index 1a7b2b815d..3a8637bfb1 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py3-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0rc1-py3-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -242,7 +242,7 @@ take the following steps:
issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py3-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0rc1-py3-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
@@ -350,7 +350,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
TensorFlow for Python 2.7:
<pre> (<i>targetDirectory</i>)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0rc1-py2-none-any.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -517,7 +517,7 @@ The value you specify depends on your Python version.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0rc1-py2-none-any.whl
</pre>
@@ -525,5 +525,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py2-none-any.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0rc1-py3-none-any.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index 31dcad64d4..a7c0b6970a 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -374,10 +374,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl`
file depends on your platform. For example, the following command will install
the pip package
-for TensorFlow 1.9.0 on Linux:
+for TensorFlow 1.10.0rc1 on Linux:
<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.9.0-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.10.0rc1-py2-none-any.whl</b>
</pre>
## Validate your installation
@@ -483,6 +483,8 @@ the error message, ask a new question on Stack Overflow and specify the
**Linux**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.10.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.15.0</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.10.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.15.0</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.11.0</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow_gpu-1.9.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.11.0</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.10.0</td><td>N/A</td><td>N/A</td></tr>
@@ -508,6 +510,7 @@ the error message, ask a new question on Stack Overflow and specify the
**Mac**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.10.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.15.0</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.11.0</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.10.1</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.7.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.10.1</td><td>N/A</td><td>N/A</td></tr>
@@ -525,6 +528,8 @@ the error message, ask a new question on Stack Overflow and specify the
**Windows**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.10.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.10.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.9.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow_gpu-1.9.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.8.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md
index 6724d1eaf8..7202ef47f7 100644
--- a/tensorflow/docs_src/performance/xla/jit.md
+++ b/tensorflow/docs_src/performance/xla/jit.md
@@ -19,10 +19,11 @@ on the `XLA_CPU` or `XLA_GPU` TensorFlow devices. Placing operators directly on
a TensorFlow XLA device forces the operator to run on that device and is mainly
used for testing.
-> Note: The XLA CPU backend produces fast single-threaded code (in most cases),
-> but does not yet parallelize as well as the TensorFlow CPU backend. The XLA
-> GPU backend is competitive with the standard TensorFlow implementation,
-> sometimes faster, sometimes slower.
+> Note: The XLA CPU backend supports intra-op parallelism (i.e. it can shard a
+> single operation across multiple cores) but it does not support inter-op
+> parallelism (i.e. it cannot execute independent operations concurrently across
+> multiple cores). The XLA GPU backend is competitive with the standard
+> TensorFlow implementation, sometimes faster, sometimes slower.
### Turning on JIT compilation
@@ -55,8 +56,7 @@ sess = tf.Session(config=config)
> Note: Turning on JIT at the session level will not result in operations being
> compiled for the CPU. JIT compilation for CPU operations must be done via
-> the manual method documented below. This decision was made due to the CPU
-> backend being single-threaded.
+> the manual method documented below.
#### Manual
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 5f7482f90f..edc777a3c7 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1431,19 +1431,29 @@ complete and returns the received data.
See also
[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-Applies a reduction function to an array.
+Applies a reduction function to one or more arrays in parallel.
-<b> `Reduce(operand, init_value, computation, dimensions)` </b>
+<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
-Arguments | Type | Semantics
-------------- | ---------------- | ---------------------------------------
-`operand` | `XlaOp` | array of type `T`
-`init_value` | `XlaOp` | scalar of type `T`
-`computation` | `XlaComputation` | computation of type `T, T -> T`
-`dimensions` | `int64` array | unordered array of dimensions to reduce
+Arguments | Type | Semantics
+------------- | --------------------- | ---------------------------------------
+`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`.
+`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`.
+`computation` | `XlaComputation` | computation of type
+ : : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`
+`dimensions` | `int64` array | unordered array of dimensions to reduce
-This operation reduces one or more dimensions of the input array into scalars.
-The rank of the returned array is `rank(operand) - len(dimensions)`.
+Where:
+* N is required to be greater than or equal to 1.
+* All input arrays must have the same dimensions.
+* If `N = 1`, `Collate(T)` is `T`.
+* If `N > 1`, `Collate(T_0, ..., T_N)` is a tuple of `N` elements of type `T`.
+
+The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type
+`T_i`, the dimensions of which are described below.
+
+This operation reduces one or more dimensions of each input array into scalars.
+The rank of each returned array is `rank(operand) - len(dimensions)`.
`init_value` is the initial value used for every reduction and may be inserted
anywhere during computation by the back-end. In most cases, `init_value` is an
identity of the reduction function (for example, 0 for addition). The applied
@@ -1459,9 +1469,9 @@ enough to being associative for most practical uses. It is possible to conceive
of some completely non-associative reductions, however, and these will produce
incorrect or unpredictable results in XLA reductions.
-As an example, when reducing across the one dimension in a 1D array with values
-[10, 11, 12, 13], with reduction function `f` (this is `computation`) then that
-could be computed as
+As an example, when reducing across one dimension in a single 1D array with
+values [10, 11, 12, 13], with reduction function `f` (this is `computation`)
+then that could be computed as
`f(10, f(11, f(12, f(init_value, 13))))`
@@ -1543,6 +1553,34 @@ the 1D array `| 20 28 36 |`.
Reducing the 3D array over all its dimensions produces the scalar `84`.
+When `N > 1`, reduce function application is slightly more complex, as it is
+applied simultaneously to all inputs. For example, consider the following
+reduction function, which can be used to compute the max and the argmax of a
+1-D tensor in parallel:
+
+```
+f: (Float, Int, Float, Int) -> Float, Int
+f(max, argmax, value, index):
+  if value >= max:
+ return (value, index)
+ else:
+ return (max, argmax)
+```
+
+For 1-D input arrays `V = Float[N], K = Int[N]`, and init values
+`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
+input dimension is equivalent to the following recursive application:
+```
+f_0 = f(I_V, I_K, V_0, K_0)
+f_1 = f(f_0.first, f_0.second, V_1, K_1)
+...
+f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
+```
+
+Applying this reduction to an array of values, and an array of sequential
+indices (i.e. iota), will co-iterate over the arrays, and return a tuple
+containing the maximal value and the matching index.
+
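To make the recursion concrete, here is a small self-contained sketch in plain Go (illustrative only, not XLA code, and assuming the comparison against `max` shown above) that applies the same (max, argmax) reduction to a values array and an iota index array:

```go
package main

import "fmt"

// reduceMaxArgmax mirrors the variadic Reduce example above: it folds a values
// array and an index array (an iota) with the (max, argmax) reduction
// function, starting from the given init values.
func reduceMaxArgmax(values []float64, indices []int64, initMax float64, initArgmax int64) (float64, int64) {
	max, argmax := initMax, initArgmax
	for i := range values {
		// f(max, argmax, value, index): keep the pair with the larger value.
		if values[i] >= max {
			max, argmax = values[i], indices[i]
		}
	}
	return max, argmax
}

func main() {
	v := []float64{10, 13, 11, 12}
	k := []int64{0, 1, 2, 3} // iota indices
	max, argmax := reduceMaxArgmax(v, k, -1e308, -1)
	fmt.Println(max, argmax) // 13 1
}
```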
## ReducePrecision
See also
@@ -1801,6 +1839,138 @@ is implementation-defined.
: : : limit of interval :
| `shape` | `Shape` | Output shape of type T |
+## Scatter
+
+The XLA scatter operation generates a result which is the value of the input
+tensor `operand`, with several slices (at indices specified by
+`scatter_indices`) updated with the values in `updates` using
+`update_computation`.
+
+See also
+[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+
+<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b>
+
+|Arguments | Type | Semantics |
+|------------------|------------------------|----------------------------------|
+|`operand` | `XlaOp` | Tensor to be scattered into. |
+|`scatter_indices` | `XlaOp` | Tensor containing the starting |
+: : : indices of the slices that must :
+: : : be scattered to. :
+|`updates` | `XlaOp` | Tensor containing the values that|
+: : : must be used for scattering. :
+|`update_computation`| `XlaComputation` | Computation to be used for |
+: : : combining the existing values in :
+: : : the input tensor and the updates :
+: : : during scatter. This computation :
+: : : should be of type `T, T -> T`. :
+|`index_vector_dim`| `int64` | The dimension in |
+: : : `scatter_indices` that contains :
+: : : the starting indices. :
+|`update_window_dims`| `ArraySlice<int64>` | The set of dimensions in |
+: : : `updates` shape that are _window :
+: : : dimensions_. :
+|`inserted_window_dims`| `ArraySlice<int64>`| The set of _window dimensions_ |
+: : : that must be inserted into :
+: : : `updates` shape. :
+|`scatter_dims_to_operand_dims`| `ArraySlice<int64>` | A dimensions map from |
+: : : the scatter indices to the :
+: : : operand index space. This array :
+: : : is interpreted as mapping `i` to :
+: : : `scatter_dims_to_operand_dims[i]`:
+: : : . It has to be one-to-one and :
+: : : total. :
+
+If `index_vector_dim` is equal to `scatter_indices.rank`, we implicitly consider
+`scatter_indices` to have a trailing `1` dimension.
+
+We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
+dimensions in `updates` shape that are not in `update_window_dims`, in ascending
+order.
+
+The arguments of scatter should follow these constraints:
+
+ - `updates` tensor must be of rank `update_window_dims.size +
+ scatter_indices.rank - 1`.
+
+ - Bounds of dimension `i` in `updates` must conform to the following:
+ - If `i` is present in `update_window_dims` (i.e. equal to
+ `update_window_dims`[`k`] for some `k`), then the bound of dimension
+ `i` in `updates` must not exceed the corresponding bound of `operand`
+ after accounting for the `inserted_window_dims` (i.e.
+ `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
+ the bounds of `operand` with the bounds at indices
+ `inserted_window_dims` removed).
+ - If `i` is present in `update_scatter_dims` (i.e. equal to
+ `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
+ `i` in `updates` must be equal to the corresponding bound of
+ `scatter_indices`, skipping `index_vector_dim` (i.e.
+ `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
+ `scatter_indices.shape.dims`[`k+1`] otherwise).
+
+ - `update_window_dims` must be in ascending order, not have any repeating
+ dimension numbers, and be in the range `[0, updates.rank)`.
+
+ - `inserted_window_dims` must be in ascending order, not have any
+ repeating dimension numbers, and be in the range `[0, operand.rank)`.
+
+ - `scatter_dims_to_operand_dims.size` must be equal to
+ `scatter_indices`[`index_vector_dim`], and its values must be in the range
+ `[0, operand.rank)`.
+
+For a given index `U` in the `updates` tensor, the corresponding index `I` in
+the `operand` tensor into which this update has to be applied is computed as
+follows:
+
+ 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
+ an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] =
+ `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at
+    position `index_vector_dim` into A.
+ 2. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering
+ `S` using the `scatter_dims_to_operand_dims` map. More formally:
+ 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
+ `k` < `scatter_dims_to_operand_dims.size`.
+ 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
+ at `update_window_dims` in `U` according to `inserted_window_dims`.
+ More formally:
+ 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if
+ `k` < `update_window_dims.size`, where `window_dims_to_operand_dims`
+ is the monotonic function with domain [`0`, `update_window_dims.size`)
+ and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For
+ example, if `update_window_dims.size` is `4`, `operand.rank` is `6`,
+ and `inserted_window_dims` is {`0`, `2`} then
+ `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`,
+ `3`→`5`}).
+ 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 4. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+ addition.
+
+In summary, the scatter operation can be defined as follows.
+
+ - Initialize `output` with `operand`, i.e. for all indices `O` in the
+ `operand` tensor:\
+ `output`[`O`] = `operand`[`O`]
+ - For every index `U` in the `updates` tensor and the corresponding index `O`
+ in the `operand` tensor:\
+ `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`])
+
+The order in which updates are applied is non-deterministic. So, when multiple
+indices in `updates` refer to the same index in `operand`, the corresponding
+value in `output` will be non-deterministic.
+
+Note that the first parameter that is passed into the `update_computation` will
+always be the current value from the `output` tensor and the second parameter
+will always be the value from the `updates` tensor. This is important
+specifically for cases when the `update_computation` is _not commutative_.
+
+Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
+the scatter op updates the elements in the input that are extracted by the
+corresponding gather op.
+
+For a detailed informal description and examples, refer to the
+"Informal Description" section under `Gather`.
+
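As an informal illustration (not the XLA implementation), the following plain-Go sketch applies the scatter recipe above in its simplest form: a 1-D operand, scalar update windows, and an additive `update_computation`:

```go
package main

import "fmt"

// scatterAdd1D follows the summary above: output starts as a copy of operand,
// and for every update U the value at operand index scatterIndices[U] is
// combined with updates[U] by the update computation (addition here).
func scatterAdd1D(operand []float32, scatterIndices []int, updates []float32) []float32 {
	output := append([]float32(nil), operand...) // output[O] = operand[O]
	for u, o := range scatterIndices {
		// With an additive, commutative update computation the result is
		// deterministic even when indices repeat.
		output[o] += updates[u]
	}
	return output
}

func main() {
	operand := []float32{1, 1, 1, 1, 1}
	indices := []int{1, 3, 1} // index 1 appears twice: both updates are applied
	updates := []float32{10, 20, 30}
	fmt.Println(scatterAdd1D(operand, indices, updates)) // [1 41 1 21 1]
}
```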
## Select
See also
diff --git a/tensorflow/docs_src/performance/xla/tfcompile.md b/tensorflow/docs_src/performance/xla/tfcompile.md
index 8521d7eacb..e4b803164f 100644
--- a/tensorflow/docs_src/performance/xla/tfcompile.md
+++ b/tensorflow/docs_src/performance/xla/tfcompile.md
@@ -205,10 +205,7 @@ representing the inputs, `results` representing the outputs, and `temps`
representing temporary buffers used internally to perform the computation. By
default, each instance of the generated class allocates and manages all of these
buffers for you. The `AllocMode` constructor argument may be used to change this
-behavior. A convenience library is provided in
-[`tensorflow/compiler/aot/runtime.h`](https://www.tensorflow.org/code/tensorflow/compiler/aot/runtime.h)
-to help with manual buffer allocation; usage of this library is optional. All
-buffers should be aligned to 32-byte boundaries.
+behavior. All buffers are aligned to 64-byte boundaries.
The generated C++ class is just a wrapper around the low-level code generated by
XLA.
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 1e765d1cd7..ca1521e641 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -334,8 +334,12 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua
// the given `shape` according to indices. This operator is the inverse of the
// @{tf.gather_nd} operator which extracts values or slices from a given tensor.
//
+// If `indices` contains duplicates, then their updates are accumulated (summed).
+//
// **WARNING**: The order in which updates are applied is nondeterministic, so the
-// output will be nondeterministic if `indices` contains duplicates.
+// output will be nondeterministic if `indices` contains duplicates -- because
+// of some numerical approximation issues, numbers summed in different order
+// may yield different results.
//
// `indices` is an integer tensor containing indices into a new tensor of shape
// `shape`. The last dimension of `indices` can be at most the rank of `shape`:
@@ -3258,6 +3262,127 @@ func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf
return op.Output(0)
}
+// DecodeWavAttr is an optional argument to DecodeWav.
+type DecodeWavAttr func(optionalAttr)
+
+// DecodeWavDesiredChannels sets the optional desired_channels attribute to value.
+//
+// value: Number of sample channels wanted.
+// If not specified, defaults to -1
+func DecodeWavDesiredChannels(value int64) DecodeWavAttr {
+ return func(m optionalAttr) {
+ m["desired_channels"] = value
+ }
+}
+
+// DecodeWavDesiredSamples sets the optional desired_samples attribute to value.
+//
+// value: Length of audio requested.
+// If not specified, defaults to -1
+func DecodeWavDesiredSamples(value int64) DecodeWavAttr {
+ return func(m optionalAttr) {
+ m["desired_samples"] = value
+ }
+}
+
+// Decode a 16-bit PCM WAV file to a float tensor.
+//
+// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float.
+//
+// When desired_channels is set, if the input contains fewer channels than this
+// then the last channel will be duplicated to give the requested number, else if
+// the input has more channels than requested then the additional channels will be
+// ignored.
+//
+// If desired_samples is set, then the audio will be cropped or padded with zeroes
+// to the requested length.
+//
+// The first output contains a Tensor with the content of the audio samples. The
+// lowest dimension will be the number of channels, and the second will be the
+// number of samples. For example, a ten-sample-long stereo WAV file should give an
+// output shape of [10, 2].
+//
+// Arguments:
+// contents: The WAV-encoded audio, usually from a file.
+//
+// Returns 2-D with shape `[length, channels]`. Scalar holding the sample rate found in the WAV header.
+func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeWav",
+ Input: []tf.Input{
+ contents,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
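A hedged usage sketch of the generated wrapper above, assuming the standard Go bindings (`github.com/tensorflow/tensorflow/tensorflow/go` and `.../go/op`) and a hypothetical local file `speech.wav`:

```go
package main

import (
	"fmt"
	"io/ioutil"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	wav, err := ioutil.ReadFile("speech.wav") // hypothetical input file
	if err != nil {
		panic(err)
	}
	s := op.NewScope()
	// Feed the raw WAV bytes through a string placeholder and decode to mono.
	contents := op.Placeholder(s, tf.String)
	audio, sampleRate := op.DecodeWav(s, contents, op.DecodeWavDesiredChannels(1))
	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	t, err := tf.NewTensor(string(wav))
	if err != nil {
		panic(err)
	}
	out, err := sess.Run(map[tf.Output]*tf.Tensor{contents: t},
		[]tf.Output{audio, sampleRate}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println("samples:", out[0].Shape(), "sample rate:", out[1].Value())
}
```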
+// UnbatchAttr is an optional argument to Unbatch.
+type UnbatchAttr func(optionalAttr)
+
+// UnbatchContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func UnbatchContainer(value string) UnbatchAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// UnbatchSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func UnbatchSharedName(value string) UnbatchAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Reverses the operation of Batch for a single output Tensor.
+//
+// An instance of Unbatch either receives an empty batched_tensor, in which case it
+// asynchronously waits until the values become available from a concurrently
+// running instance of Unbatch with the same container and shared_name, or receives
+// a non-empty batched_tensor in which case it finalizes all other concurrently
+// running instances and outputs its own element from the batch.
+//
+// batched_tensor: The possibly transformed output of Batch. The size of the first
+// dimension should remain unchanged by the transformations for the operation to
+// work.
+// batch_index: The matching batch_index obtained from Batch.
+// id: The id scalar emitted by Batch.
+// unbatched_tensor: The Tensor corresponding to this execution.
+// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the
+// batched input tensor associated with a given invocation of the op.
+// container: Container to control resource sharing.
+// shared_name: Instances of Unbatch with the same container and shared_name are
+// assumed to possibly belong to the same batch. If left empty, the op name will
+// be used as the shared name.
+func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"timeout_micros": timeout_micros}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Unbatch",
+ Input: []tf.Input{
+ batched_tensor, batch_index, id,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the mean along sparse segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
@@ -7376,6 +7501,272 @@ func Acos(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// UnbatchGradAttr is an optional argument to UnbatchGrad.
+type UnbatchGradAttr func(optionalAttr)
+
+// UnbatchGradContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func UnbatchGradContainer(value string) UnbatchGradAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// UnbatchGradSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func UnbatchGradSharedName(value string) UnbatchGradAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Gradient of Unbatch.
+//
+// Acts like Batch but using the given batch_index index of batching things as they
+// become available. This ensures that the gradients are propagated back in the
+// same session which did the forward pass.
+//
+// original_input: The input to the Unbatch operation this is the gradient of.
+// batch_index: The batch_index given to the Unbatch operation this is the gradient
+// of.
+// grad: The downstream gradient.
+// id: The id scalar emitted by Batch.
+// batched_grad: The return value, either an empty tensor or the batched gradient.
+// container: Container to control resource sharing.
+// shared_name: Instances of UnbatchGrad with the same container and shared_name
+// are assumed to possibly belong to the same batch. If left empty, the op name
+// will be used as the shared name.
+func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UnbatchGrad",
+ Input: []tf.Input{
+ original_input, batch_index, grad, id,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad.
+type AvgPool3DGradAttr func(optionalAttr)
+
+// AvgPool3DGradDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format of the input and output data. With the
+// default format "NDHWC", the data is stored in the order of:
+// [batch, in_depth, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCDHW", the data storage order is:
+// [batch, in_channels, in_depth, in_height, in_width].
+// If not specified, defaults to "NDHWC"
+func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Computes gradients of average pooling function.
+//
+// Arguments:
+// orig_input_shape: The original input dimensions.
+// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
+// ksize: 1-D tensor of length 5. The size of the window for each dimension of
+// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+// strides: 1-D tensor of length 5. The stride of the sliding window for each
+// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+// padding: The type of padding algorithm to use.
+//
+// Returns The backprop for input.
+func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AvgPool3DGrad",
+ Input: []tf.Input{
+ orig_input_shape, grad,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample.
+type ParseSingleSequenceExampleAttr func(optionalAttr)
+
+// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value.
+//
+// value: A list of Ncontext_sparse types; the data types of data in
+// each context Feature given in context_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_sparse_types"] = value
+ }
+}
+
+// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_types"] = value
+ }
+}
+
+// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value.
+//
+// value: A list of Ncontext_dense shapes; the shapes of data in
+// each context Feature given in context_dense_keys.
+// The number of elements in the Feature corresponding to context_dense_key[j]
+// must always equal context_dense_shapes[j].NumEntries().
+// The shape of context_dense_values[j] will match context_dense_shapes[j].
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_dense_shapes"] = value
+ }
+}
+
+// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value.
+//
+// value: A list of Nfeature_list_sparse types; the data types
+// of data in each FeatureList given in feature_list_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_sparse_types"] = value
+ }
+}
+
+// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value.
+//
+// value: A list of Nfeature_list_dense shapes; the shapes of
+// data in each FeatureList given in feature_list_dense_keys.
+// The shape of each Feature in the FeatureList corresponding to
+// feature_list_dense_key[j] must always equal
+// feature_list_dense_shapes[j].NumEntries().
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_shapes"] = value
+ }
+}
+
+// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors.
+//
+// Arguments:
+// serialized: A scalar containing a binary serialized SequenceExample proto.
+// feature_list_dense_missing_assumed_empty: A vector listing the
+// FeatureList keys which may be missing from the SequenceExample. If the
+// associated FeatureList is missing, it is treated as empty. By default,
+// any FeatureList not listed in this vector must exist in the SequenceExample.
+// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars).
+// The keys expected in the Examples' features associated with context_sparse
+// values.
+// context_dense_keys: A list of Ncontext_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' context features associated with
+// dense values.
+// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
+// (scalars). The keys expected in the FeatureLists associated with sparse
+// values.
+// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' feature_lists associated
+// with lists of dense values.
+// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty).
+// context_dense_defaults[j] provides default values
+// when the SequenceExample's context map lacks context_dense_key[j].
+// If an empty Tensor is provided for context_dense_defaults[j],
+// then the Feature context_dense_keys[j] is required.
+// The input type is inferred from context_dense_defaults[j], even when it's
+// empty. If context_dense_defaults[j] is not empty, its shape must match
+// context_dense_shapes[j].
+// debug_name: A scalar containing the name of the serialized proto.
+// May contain, for example, table key (descriptive) name for the
+// corresponding serialized proto. This is purely useful for debugging
+// purposes, and the presence of values here has no effect on the output.
+// May also be an empty scalar if no name is available.
+func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ParseSingleSequenceExample",
+ Input: []tf.Input{
+ serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleSequenceExample", err)
+ return
+ }
+ return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values
+}
+
// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize.
type QuantizeAndDequantizeAttr func(optionalAttr)
@@ -8283,6 +8674,101 @@ func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ..
return op.Output(0)
}
+// Encode audio data using the WAV file format.
+//
+// This operation will generate a string suitable to be saved out to create a .wav
+// audio file. It will be encoded in the 16-bit PCM format. It takes in float
+// values in the range -1.0f to 1.0f, and any value outside it will be clamped to
+// that range.
+//
+// `audio` is a 2-D float Tensor of shape `[length, channels]`.
+// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100).
+//
+// Arguments:
+// audio: 2-D with shape `[length, channels]`.
+// sample_rate: Scalar containing the sample frequency.
+//
+// Returns 0-D. WAV-encoded file contents.
+func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeWav",
+ Input: []tf.Input{
+ audio, sample_rate,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes atan of x element-wise.
+func Atan(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Atan",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax.
+type ResourceApplyAdaMaxAttr func(optionalAttr)
+
+// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, m, and v tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the AdaMax algorithm.
+//
+// m_t <- beta1 * m_{t-1} + (1 - beta1) * g
+// v_t <- max(beta2 * v_{t-1}, abs(g))
+// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
+//
+// Arguments:
+// var_: Should be from a Variable().
+// m: Should be from a Variable().
+// v: Should be from a Variable().
+// beta1_power: Must be a scalar.
+// lr: Scaling factor. Must be a scalar.
+// beta1: Momentum factor. Must be a scalar.
+// beta2: Momentum factor. Must be a scalar.
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+//
+// Returns the created operation.
+func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyAdaMax",
+ Input: []tf.Input{
+ var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
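For intuition only, a plain-Go sketch of the scalar form of the update rule documented above (the op itself applies it element-wise to the variable, with optional locking):

```go
package main

import (
	"fmt"
	"math"
)

// adaMaxStep applies one AdaMax update to a single scalar parameter, following
// the equations in the comment above.
func adaMaxStep(variable, m, v, grad, lr, beta1, beta2, eps, beta1Power float64) (newVar, newM, newV float64) {
	newM = beta1*m + (1-beta1)*grad          // m_t
	newV = math.Max(beta2*v, math.Abs(grad)) // v_t
	newVar = variable - lr/(1-beta1Power)*newM/(newV+eps)
	return
}

func main() {
	x, m, u := 1.0, 0.0, 0.0 // parameter, first moment, infinity-norm estimate
	for t := 1; t <= 3; t++ {
		grad := 2 * x // gradient of f(x) = x^2
		x, m, u = adaMaxStep(x, m, u, grad, 0.1, 0.9, 0.999, 1e-8, math.Pow(0.9, float64(t)))
		fmt.Printf("step %d: x = %.4f\n", t, x)
	}
}
```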
// AssertAttr is an optional argument to Assert.
type AssertAttr func(optionalAttr)
@@ -9253,7 +9739,7 @@ func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr {
// 8 elements. In Python, that update would look like this:
//
// ```python
-// ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True)
// indices = tf.constant([[4], [3], [1] ,[7]])
// updates = tf.constant([9, 10, 11, 12])
// update = tf.scatter_nd_add(ref, indices, updates)
@@ -10457,101 +10943,6 @@ func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Outpu
return op.Output(0), op.Output(1)
}
-// Computes atan of x element-wise.
-func Atan(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Atan",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax.
-type ResourceApplyAdaMaxAttr func(optionalAttr)
-
-// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, m, and v tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the AdaMax algorithm.
-//
-// m_t <- beta1 * m_{t-1} + (1 - beta1) * g
-// v_t <- max(beta2 * v_{t-1}, abs(g))
-// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
-//
-// Arguments:
-// var_: Should be from a Variable().
-// m: Should be from a Variable().
-// v: Should be from a Variable().
-// beta1_power: Must be a scalar.
-// lr: Scaling factor. Must be a scalar.
-// beta1: Momentum factor. Must be a scalar.
-// beta2: Momentum factor. Must be a scalar.
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
-//
-// Returns the created operation.
-func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyAdaMax",
- Input: []tf.Input{
- var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// Encode audio data using the WAV file format.
-//
-// This operation will generate a string suitable to be saved out to create a .wav
-// audio file. It will be encoded in the 16-bit PCM format. It takes in float
-// values in the range -1.0f to 1.0f, and any outside that value will be clamped to
-// that range.
-//
-// `audio` is a 2-D float Tensor of shape `[length, channels]`.
-// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100).
-//
-// Arguments:
-// audio: 2-D with shape `[length, channels]`.
-// sample_rate: Scalar containing the sample frequency.
-//
-// Returns 0-D. WAV-encoded file contents.
-func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "EncodeWav",
- Input: []tf.Input{
- audio, sample_rate,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Converts each string in the input Tensor to its hash mod by a number of buckets.
//
// The hash function is deterministic on the content of the string within the
@@ -12399,7 +12790,7 @@ func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr {
// 8 elements. In Python, that update would look like this:
//
// ```python
-// ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
// indices = tf.constant([[4], [3], [1] ,[7]])
// updates = tf.constant([9, 10, 11, 12])
// update = tf.scatter_nd_update(ref, indices, updates)
@@ -21581,7 +21972,7 @@ func PaddedBatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.O
return op.Output(0)
}
-// Returns element-wise smallest integer in not less than x.
+// Returns element-wise smallest integer not less than x.
func Ceil(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
return
@@ -24308,6 +24699,145 @@ func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Out
return op.Output(0), op.Output(1)
}
+// BatchAttr is an optional argument to Batch.
+type BatchAttr func(optionalAttr)
+
+// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value.
+// If not specified, defaults to 10
+func BatchMaxEnqueuedBatches(value int64) BatchAttr {
+ return func(m optionalAttr) {
+ m["max_enqueued_batches"] = value
+ }
+}
+
+// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value.
+// If not specified, defaults to <>
+func BatchAllowedBatchSizes(value []int64) BatchAttr {
+ return func(m optionalAttr) {
+ m["allowed_batch_sizes"] = value
+ }
+}
+
+// BatchContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func BatchContainer(value string) BatchAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// BatchSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func BatchSharedName(value string) BatchAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// BatchBatchingQueue sets the optional batching_queue attribute to value.
+// If not specified, defaults to ""
+func BatchBatchingQueue(value string) BatchAttr {
+ return func(m optionalAttr) {
+ m["batching_queue"] = value
+ }
+}
+
+// Batches all input tensors nondeterministically.
+//
+// When many instances of this Op are being run concurrently with the same
+// container/shared_name in the same device, some will output zero-shaped Tensors
+// and others will output Tensors of size up to max_batch_size.
+//
+// All Tensors in in_tensors are batched together (so, for example, labels and
+// features should be batched with a single instance of this operation).
+//
+// Each invocation of batch emits an `id` scalar which will be used to identify
+// this particular invocation when doing unbatch or its gradient.
+//
+// Each op which emits a non-empty batch will also emit a non-empty batch_index
+// Tensor, which is a [K, 3] matrix where each row contains the invocation's id,
+// start, and length of elements of each set of Tensors present in batched_tensors.
+//
+// Batched tensors are concatenated along the first dimension, and all tensors in
+// in_tensors must have the first dimension of the same size.
+//
+// in_tensors: The tensors to be batched.
+// num_batch_threads: Number of scheduling threads for processing batches of work.
+// Determines the number of batches processed in parallel.
+// max_batch_size: Batch sizes will never be bigger than this.
+// batch_timeout_micros: Maximum number of microseconds to wait before outputting
+// an incomplete batch.
+// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does
+// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad
+// batches up to one of those sizes. The entries must increase monotonically, and
+// the final entry must equal max_batch_size.
+// grad_timeout_micros: The timeout to use for the gradient. See Unbatch.
+// batched_tensors: Either empty tensors or a batch of concatenated Tensors.
+// batch_index: If out_tensors is non-empty, has information to invert it.
+// container: Controls the scope of sharing of this batch.
+// id: always contains a scalar with a unique ID for this invocation of Batch.
+// shared_name: Concurrently running instances of batch in the same device with the
+// same container and shared_name will batch their elements together. If left
+// empty, the op name will be used as the shared name.
+// T: the types of tensors to be batched.
+func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Batch",
+ Input: []tf.Input{
+ tf.OutputList(in_tensors),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil {
+ scope.UpdateErr("Batch", err)
+ return
+ }
+ batch_index = op.Output(idx)
+	id = op.Output(idx + 1)
+ return batched_tensors, batch_index, id
+}
+
+// Adjust the hue of one or more images.
+//
+// `images` is a tensor of at least 3 dimensions. The last dimension is
+// interpreted as channels, and must be three.
+//
+// The input image is considered in the RGB colorspace. Conceptually, the RGB
+// colors are first mapped into HSV. A delta is then applied to all the hue values,
+// and then remapped back to RGB colorspace.
+//
+// Arguments:
+// images: Images to adjust. At least 3-D.
+// delta: A float delta to add to the hue.
+//
+// Returns The hue-adjusted image or images.
+func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AdjustHue",
+ Input: []tf.Input{
+ images, delta,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
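Another hedged usage sketch, under the same Go-binding assumptions as the DecodeWav example, shifting the hue of a single constant RGB pixel through this wrapper:

```go
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// A 1x1 RGB image: at least 3-D, with the last dimension of size three.
	image := op.Const(s, [][][]float32{{{0.9, 0.1, 0.1}}})
	delta := op.Const(s, float32(0.5)) // rotate the hue by half a turn
	adjusted := op.AdjustHue(s, image, delta)

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{adjusted}, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(out[0].Value()) // hue-shifted RGB values
}
```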
// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam.
type ResourceApplyAdamAttr func(optionalAttr)
@@ -25358,6 +25888,73 @@ func NonMaxSuppressionV3(scope *Scope, boxes tf.Output, scores tf.Output, max_ou
return op.Output(0)
}
+// NonMaxSuppressionV4Attr is an optional argument to NonMaxSuppressionV4.
+type NonMaxSuppressionV4Attr func(optionalAttr)
+
+// NonMaxSuppressionV4PadToMaxOutputSize sets the optional pad_to_max_output_size attribute to value.
+//
+// value: If true, the output `selected_indices` is padded to be of length
+// `max_output_size`. Defaults to false.
+// If not specified, defaults to false
+func NonMaxSuppressionV4PadToMaxOutputSize(value bool) NonMaxSuppressionV4Attr {
+ return func(m optionalAttr) {
+ m["pad_to_max_output_size"] = value
+ }
+}
+
+// Greedily selects a subset of bounding boxes in descending order of score,
+//
+// pruning away boxes that have high intersection-over-union (IOU) overlap
+// with previously selected boxes. Bounding boxes with score less than
+// `score_threshold` are removed. Bounding boxes are supplied as
+// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
+// diagonal pair of box corners and the coordinates can be provided as normalized
+// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
+// is agnostic to where the origin is in the coordinate system and more
+// generally is invariant to orthogonal transformations and translations
+// of the coordinate system; thus translations or reflections of the coordinate
+// system result in the same boxes being selected by the algorithm.
+// The output of this operation is a set of integers indexing into the input
+// collection of bounding boxes representing the selected boxes. The bounding
+// box coordinates corresponding to the selected indices can then be obtained
+// using the `tf.gather` operation. For example:
+// selected_indices = tf.image.non_max_suppression_v2(
+// boxes, scores, max_output_size, iou_threshold, score_threshold)
+// selected_boxes = tf.gather(boxes, selected_indices)
+//
+// Arguments:
+// boxes: A 2-D float tensor of shape `[num_boxes, 4]`.
+// scores: A 1-D float tensor of shape `[num_boxes]` representing a single
+// score corresponding to each box (each row of boxes).
+// max_output_size: A scalar integer tensor representing the maximum number of
+// boxes to be selected by non max suppression.
+// iou_threshold: A 0-D float tensor representing the threshold for deciding whether
+// boxes overlap too much with respect to IOU.
+// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove
+// boxes based on score.
+//
+// Returns A 1-D integer tensor of shape `[M]` representing the selected
+// indices from the boxes tensor, where `M <= max_output_size`. A 0-D integer tensor representing the number of valid elements in
+// `selected_indices`, with the valid elements appearing first.
+func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output, optional ...NonMaxSuppressionV4Attr) (selected_indices tf.Output, valid_outputs tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "NonMaxSuppressionV4",
+ Input: []tf.Input{
+ boxes, scores, max_output_size, iou_threshold, score_threshold,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// Computes the matrix logarithm of one or more square matrices:
//
//
@@ -25560,132 +26157,6 @@ func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source
return op.Output(0), op.Output(1)
}
-// DecodeProtoV2Attr is an optional argument to DecodeProtoV2.
-type DecodeProtoV2Attr func(optionalAttr)
-
-// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value.
-//
-// value: Either the special value `local://` or a path to a file containing
-// a serialized `FileDescriptorSet`.
-// If not specified, defaults to "local://"
-func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr {
- return func(m optionalAttr) {
- m["descriptor_source"] = value
- }
-}
-
-// DecodeProtoV2MessageFormat sets the optional message_format attribute to value.
-//
-// value: Either `binary` or `text`.
-// If not specified, defaults to "binary"
-func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr {
- return func(m optionalAttr) {
- m["message_format"] = value
- }
-}
-
-// DecodeProtoV2Sanitize sets the optional sanitize attribute to value.
-//
-// value: Whether to sanitize the result or not.
-// If not specified, defaults to false
-func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr {
- return func(m optionalAttr) {
- m["sanitize"] = value
- }
-}
-
-// The op extracts fields from a serialized protocol buffers message into tensors.
-//
-// The `decode_proto` op extracts fields from a serialized protocol buffers
-// message into tensors. The fields in `field_names` are decoded and converted
-// to the corresponding `output_types` if possible.
-//
-// A `message_type` name must be provided to give context for the field
-// names. The actual message descriptor can be looked up either in the
-// linked-in descriptor pool or a filename provided by the caller using
-// the `descriptor_source` attribute.
-//
-// Each output tensor is a dense tensor. This means that it is padded to
-// hold the largest number of repeated elements seen in the input
-// minibatch. (The shape is also padded by one to prevent zero-sized
-// dimensions). The actual repeat counts for each example in the
-// minibatch can be found in the `sizes` output. In many cases the output
-// of `decode_proto` is fed immediately into tf.squeeze if missing values
-// are not a concern. When using tf.squeeze, always pass the squeeze
-// dimension explicitly to avoid surprises.
-//
-// For the most part, the mapping between Proto field types and
-// TensorFlow dtypes is straightforward. However, there are a few
-// special cases:
-//
-// - A proto field that contains a submessage or group can only be converted
-// to `DT_STRING` (the serialized submessage). This is to reduce the
-// complexity of the API. The resulting string can be used as input
-// to another instance of the decode_proto op.
-//
-// - TensorFlow lacks support for unsigned integers. The ops represent uint64
-// types as a `DT_INT64` with the same twos-complement bit pattern
-// (the obvious way). Unsigned int32 values can be represented exactly by
-// specifying type `DT_INT64`, or using twos-complement if the caller
-// specifies `DT_INT32` in the `output_types` attribute.
-//
-// The `descriptor_source` attribute selects a source of protocol
-// descriptors to consult when looking up `message_type`. This may be a
-// filename containing a serialized `FileDescriptorSet` message,
-// or the special value `local://`, in which case only descriptors linked
-// into the code will be searched; the filename can be on any filesystem
-// accessible to TensorFlow.
-//
-// You can build a `descriptor_source` file using the `--descriptor_set_out`
-// and `--include_imports` options to the protocol compiler `protoc`.
-//
-// The `local://` database only covers descriptors linked into the
-// code via C++ libraries, not Python imports. You can link in a proto descriptor
-// by creating a cc_library target with alwayslink=1.
-//
-// Both binary and text proto serializations are supported, and can be
-// chosen using the `format` attribute.
-//
-// Arguments:
-// bytes: Tensor of serialized protos with shape `batch_shape`.
-// message_type: Name of the proto message type to decode.
-// field_names: List of strings containing proto field names.
-// output_types: List of TF types to use for the respective field in field_names.
-//
-// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`.
-// Each entry is the number of values found for the corresponding field.
-// Optional fields may have 0 or 1 values.List of tensors containing values for the corresponding field.
-// `values[i]` has datatype `output_types[i]`
-// and shape `[batch_shape, max(sizes[...,i])]`.
-func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DecodeProtoV2",
- Input: []tf.Input{
- bytes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- sizes = op.Output(idx)
- if values, idx, err = makeOutputList(op, idx, "values"); err != nil {
- scope.UpdateErr("DecodeProtoV2", err)
- return
- }
- return sizes, values
-}
-
// Creates a dataset that splits a SparseTensor into elements row-wise.
func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
if scope.Err() != nil {
@@ -26651,30 +27122,6 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out
return op.Output(0)
}
-// Creates a dataset that executes a SQL query and emits rows of the result set.
-//
-// Arguments:
-// driver_name: The database type. Currently, the only supported type is 'sqlite'.
-// data_source_name: A connection string to connect to the database.
-// query: A SQL query to execute.
-//
-//
-func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
- opspec := tf.OpSpec{
- Type: "SqlDataset",
- Input: []tf.Input{
- driver_name, data_source_name, query,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a dataset that emits the records from one or more binary files.
//
// Arguments:
@@ -26966,7 +27413,7 @@ func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output)
return op.Output(0)
}
-// Gets the next output from the given iterator.
+// Gets the next output from the given iterator.
func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
if scope.Err() != nil {
return
@@ -27374,6 +27821,241 @@ func SinkDataset(scope *Scope, input_dataset tf.Output) (handle tf.Output) {
return op.Output(0)
}
+// Constructs an Optional variant from a tuple of tensors.
+func OptionalFromValue(scope *Scope, components []tf.Output) (optional tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "OptionalFromValue",
+ Input: []tf.Input{
+ tf.OutputList(components),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// DecodeProtoV2Attr is an optional argument to DecodeProtoV2.
+type DecodeProtoV2Attr func(optionalAttr)
+
+// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value.
+//
+// value: Either the special value `local://` or a path to a file containing
+// a serialized `FileDescriptorSet`.
+// If not specified, defaults to "local://"
+func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr {
+ return func(m optionalAttr) {
+ m["descriptor_source"] = value
+ }
+}
+
+// DecodeProtoV2MessageFormat sets the optional message_format attribute to value.
+//
+// value: Either `binary` or `text`.
+// If not specified, defaults to "binary"
+func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr {
+ return func(m optionalAttr) {
+ m["message_format"] = value
+ }
+}
+
+// DecodeProtoV2Sanitize sets the optional sanitize attribute to value.
+//
+// value: Whether to sanitize the result or not.
+// If not specified, defaults to false
+func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr {
+ return func(m optionalAttr) {
+ m["sanitize"] = value
+ }
+}
+
+// The op extracts fields from a serialized protocol buffers message into tensors.
+//
+// The `decode_proto` op extracts fields from a serialized protocol buffers
+// message into tensors. The fields in `field_names` are decoded and converted
+// to the corresponding `output_types` if possible.
+//
+// A `message_type` name must be provided to give context for the field
+// names. The actual message descriptor can be looked up either in the
+// linked-in descriptor pool or a filename provided by the caller using
+// the `descriptor_source` attribute.
+//
+// Each output tensor is a dense tensor. This means that it is padded to
+// hold the largest number of repeated elements seen in the input
+// minibatch. (The shape is also padded by one to prevent zero-sized
+// dimensions). The actual repeat counts for each example in the
+// minibatch can be found in the `sizes` output. In many cases the output
+// of `decode_proto` is fed immediately into tf.squeeze if missing values
+// are not a concern. When using tf.squeeze, always pass the squeeze
+// dimension explicitly to avoid surprises.
+//
+// For the most part, the mapping between Proto field types and
+// TensorFlow dtypes is straightforward. However, there are a few
+// special cases:
+//
+// - A proto field that contains a submessage or group can only be converted
+// to `DT_STRING` (the serialized submessage). This is to reduce the
+// complexity of the API. The resulting string can be used as input
+// to another instance of the decode_proto op.
+//
+// - TensorFlow lacks support for unsigned integers. The ops represent uint64
+// types as a `DT_INT64` with the same twos-complement bit pattern
+// (the obvious way). Unsigned int32 values can be represented exactly by
+// specifying type `DT_INT64`, or using twos-complement if the caller
+// specifies `DT_INT32` in the `output_types` attribute.
+//
+// The `descriptor_source` attribute selects a source of protocol
+// descriptors to consult when looking up `message_type`. This may be a
+// filename containing a serialized `FileDescriptorSet` message,
+// or the special value `local://`, in which case only descriptors linked
+// into the code will be searched; the filename can be on any filesystem
+// accessible to TensorFlow.
+//
+// You can build a `descriptor_source` file using the `--descriptor_set_out`
+// and `--include_imports` options to the protocol compiler `protoc`.
+//
+// The `local://` database only covers descriptors linked into the
+// code via C++ libraries, not Python imports. You can link in a proto descriptor
+// by creating a cc_library target with alwayslink=1.
+//
+// Both binary and text proto serializations are supported, and can be
+// chosen using the `format` attribute.
+//
+// Arguments:
+// bytes: Tensor of serialized protos with shape `batch_shape`.
+// message_type: Name of the proto message type to decode.
+// field_names: List of strings containing proto field names.
+// output_types: List of TF types to use for the respective field in field_names.
+//
+// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+// Each entry is the number of values found for the corresponding field.
+// Optional fields may have 0 or 1 values. Also returns a list of tensors
+// containing values for the corresponding field:
+// `values[i]` has datatype `output_types[i]`
+// and shape `[batch_shape, max(sizes[...,i])]`.
+func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeProtoV2",
+ Input: []tf.Input{
+ bytes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ sizes = op.Output(idx)
+ if values, idx, err = makeOutputList(op, idx, "values"); err != nil {
+ scope.UpdateErr("DecodeProtoV2", err)
+ return
+ }
+ return sizes, values
+}
+
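A hedged sketch of wiring this wrapper up (same imports as the NonMaxSuppressionV4 sketch above); the message type and field names are hypothetical and would have to exist in the linked-in descriptor pool:

    s := op.NewScope()
    serialized := op.Placeholder(s, tf.String) // serialized protos with shape batch_shape
    sizes, values := op.DecodeProtoV2(s, serialized,
    	"example.Person",                   // hypothetical message_type
    	[]string{"id", "email"},            // hypothetical field_names
    	[]tf.DataType{tf.Int64, tf.String}, // output_types, one per field
    	op.DecodeProtoV2DescriptorSource("local://"),
    	op.DecodeProtoV2MessageFormat("binary"))
    _, _ = sizes, values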
+// Creates an Optional variant with no value.
+func OptionalNone(scope *Scope) (optional tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "OptionalNone",
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns true if and only if the given Optional variant has a value.
+func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "OptionalHasValue",
+ Input: []tf.Input{
+ optional,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a dataset that executes a SQL query and emits rows of the result set.
+//
+// Arguments:
+// driver_name: The database type. Currently, the only supported type is 'sqlite'.
+// data_source_name: A connection string to connect to the database.
+// query: A SQL query to execute.
+//
+//
+func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "SqlDataset",
+ Input: []tf.Input{
+ driver_name, data_source_name, query,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
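A hedged sketch of constructing this dataset from Go (same imports as above); the SQLite path and query are illustrative assumptions:

    s := op.NewScope()
    handle := op.SqlDataset(s,
    	op.Const(s.SubScope("driver"), "sqlite"),
    	op.Const(s.SubScope("dsn"), "/tmp/example.db"),          // hypothetical data_source_name
    	op.Const(s.SubScope("query"), "SELECT id, name FROM t"), // hypothetical query
    	[]tf.DataType{tf.Int64, tf.String},
    	[]tf.Shape{tf.ScalarShape(), tf.ScalarShape()})
    _ = handle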
+// Returns the value stored in an Optional variant or raises an error if none exists.
+func OptionalGetValue(scope *Scope, optional tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "OptionalGetValue",
+ Input: []tf.Input{
+ optional,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+ scope.UpdateErr("OptionalGetValue", err)
+ return
+ }
+ return components
+}
+
+// Gets the next output from the given iterator as an Optional variant.
+func IteratorGetNextAsOptional(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (optional tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "IteratorGetNextAsOptional",
+ Input: []tf.Input{
+ iterator,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
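Taken together, OptionalFromValue, OptionalNone, OptionalHasValue, OptionalGetValue and IteratorGetNextAsOptional let Go callers poll an iterator without hitting an end-of-sequence error. A minimal sketch under the same imports, assuming `iterator` is an iterator resource handle produced elsewhere (for example by the generated Iterator and MakeIterator wrappers):

    s := op.NewScope()
    outTypes := []tf.DataType{tf.Int64}
    outShapes := []tf.Shape{tf.ScalarShape()}
    opt := op.IteratorGetNextAsOptional(s, iterator, outTypes, outShapes)
    hasValue := op.OptionalHasValue(s, opt)                        // scalar bool
    components := op.OptionalGetValue(s, opt, outTypes, outShapes) // fails at run time if empty
    fallback := op.OptionalFromValue(s, []tf.Output{op.Const(s.SubScope("zero"), int64(0))})
    none := op.OptionalNone(s)
    _, _, _, _ = hasValue, components, fallback, none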
// Performs a padding as a preprocess during a convolution.
//
// Similar to FusedResizeAndPadConv2d, this op allows for an optimized
@@ -31234,529 +31916,3 @@ func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
op := scope.AddOperation(opspec)
return op.Output(0)
}
-
-// Adjust the hue of one or more images.
-//
-// `images` is a tensor of at least 3 dimensions. The last dimension is
-// interpretted as channels, and must be three.
-//
-// The input image is considered in the RGB colorspace. Conceptually, the RGB
-// colors are first mapped into HSV. A delta is then applied all the hue values,
-// and then remapped back to RGB colorspace.
-//
-// Arguments:
-// images: Images to adjust. At least 3-D.
-// delta: A float delta to add to the hue.
-//
-// Returns The hue-adjusted image or images.
-func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AdjustHue",
- Input: []tf.Input{
- images, delta,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// BatchAttr is an optional argument to Batch.
-type BatchAttr func(optionalAttr)
-
-// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value.
-// If not specified, defaults to 10
-func BatchMaxEnqueuedBatches(value int64) BatchAttr {
- return func(m optionalAttr) {
- m["max_enqueued_batches"] = value
- }
-}
-
-// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value.
-// If not specified, defaults to <>
-func BatchAllowedBatchSizes(value []int64) BatchAttr {
- return func(m optionalAttr) {
- m["allowed_batch_sizes"] = value
- }
-}
-
-// BatchContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func BatchContainer(value string) BatchAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// BatchSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func BatchSharedName(value string) BatchAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// BatchBatchingQueue sets the optional batching_queue attribute to value.
-// If not specified, defaults to ""
-func BatchBatchingQueue(value string) BatchAttr {
- return func(m optionalAttr) {
- m["batching_queue"] = value
- }
-}
-
-// Batches all input tensors nondeterministically.
-//
-// When many instances of this Op are being run concurrently with the same
-// container/shared_name in the same device, some will output zero-shaped Tensors
-// and others will output Tensors of size up to max_batch_size.
-//
-// All Tensors in in_tensors are batched together (so, for example, labels and
-// features should be batched with a single instance of this operation.
-//
-// Each invocation of batch emits an `id` scalar which will be used to identify
-// this particular invocation when doing unbatch or its gradient.
-//
-// Each op which emits a non-empty batch will also emit a non-empty batch_index
-// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id,
-// start, and length of elements of each set of Tensors present in batched_tensors.
-//
-// Batched tensors are concatenated along the first dimension, and all tensors in
-// in_tensors must have the first dimension of the same size.
-//
-// in_tensors: The tensors to be batched.
-// num_batch_threads: Number of scheduling threads for processing batches of work.
-// Determines the number of batches processed in parallel.
-// max_batch_size: Batch sizes will never be bigger than this.
-// batch_timeout_micros: Maximum number of microseconds to wait before outputting
-// an incomplete batch.
-// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does
-// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad
-// batches up to one of those sizes. The entries must increase monotonically, and
-// the final entry must equal max_batch_size.
-// grad_timeout_micros: The timeout to use for the gradient. See Unbatch.
-// batched_tensors: Either empty tensors or a batch of concatenated Tensors.
-// batch_index: If out_tensors is non-empty, has information to invert it.
-// container: Controls the scope of sharing of this batch.
-// id: always contains a scalar with a unique ID for this invocation of Batch.
-// shared_name: Concurrently running instances of batch in the same device with the
-// same container and shared_name will batch their elements together. If left
-// empty, the op name will be used as the shared name.
-// T: the types of tensors to be batched.
-func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Batch",
- Input: []tf.Input{
- tf.OutputList(in_tensors),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil {
- scope.UpdateErr("Batch", err)
- return
- }
- batch_index = op.Output(idx)
- id = op.Output(idx)
- return batched_tensors, batch_index, id
-}
-
-// UnbatchAttr is an optional argument to Unbatch.
-type UnbatchAttr func(optionalAttr)
-
-// UnbatchContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func UnbatchContainer(value string) UnbatchAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// UnbatchSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func UnbatchSharedName(value string) UnbatchAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Reverses the operation of Batch for a single output Tensor.
-//
-// An instance of Unbatch either receives an empty batched_tensor, in which case it
-// asynchronously waits until the values become available from a concurrently
-// running instance of Unbatch with the same container and shared_name, or receives
-// a non-empty batched_tensor in which case it finalizes all other concurrently
-// running instances and outputs its own element from the batch.
-//
-// batched_tensor: The possibly transformed output of Batch. The size of the first
-// dimension should remain unchanged by the transformations for the operation to
-// work.
-// batch_index: The matching batch_index obtained from Batch.
-// id: The id scalar emitted by Batch.
-// unbatched_tensor: The Tensor corresponding to this execution.
-// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the
-// batched input tensor associated with a given invocation of the op.
-// container: Container to control resource sharing.
-// shared_name: Instances of Unbatch with the same container and shared_name are
-// assumed to possibly belong to the same batch. If left empty, the op name will
-// be used as the shared name.
-func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"timeout_micros": timeout_micros}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Unbatch",
- Input: []tf.Input{
- batched_tensor, batch_index, id,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
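For reference, a hedged sketch of pairing the Batch and Unbatch wrappers documented above (thread count, sizes and timeouts are illustrative):

    s := op.NewScope()
    x := op.Placeholder(s, tf.Float)
    batched, batchIndex, id := op.Batch(s, []tf.Output{x},
    	1,      // num_batch_threads
    	32,     // max_batch_size
    	100000, // batch_timeout_micros
    	1000000 /* grad_timeout_micros */)
    // Any transformation applied to batched[0] must preserve its first dimension
    // for Unbatch to recover this invocation's element.
    y := op.Unbatch(s, batched[0], batchIndex, id, 1000000 /* timeout_micros */)
    _ = y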
-// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad.
-type AvgPool3DGradAttr func(optionalAttr)
-
-// AvgPool3DGradDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format of the input and output data. With the
-// default format "NDHWC", the data is stored in the order of:
-// [batch, in_depth, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCDHW", the data storage order is:
-// [batch, in_channels, in_depth, in_height, in_width].
-// If not specified, defaults to "NDHWC"
-func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Computes gradients of average pooling function.
-//
-// Arguments:
-// orig_input_shape: The original input dimensions.
-// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
-// ksize: 1-D tensor of length 5. The size of the window for each dimension of
-// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
-// strides: 1-D tensor of length 5. The stride of the sliding window for each
-// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
-// padding: The type of padding algorithm to use.
-//
-// Returns The backprop for input.
-func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AvgPool3DGrad",
- Input: []tf.Input{
- orig_input_shape, grad,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample.
-type ParseSingleSequenceExampleAttr func(optionalAttr)
-
-// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value.
-//
-// value: A list of Ncontext_sparse types; the data types of data in
-// each context Feature given in context_sparse_keys.
-// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["context_sparse_types"] = value
- }
-}
-
-// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["feature_list_dense_types"] = value
- }
-}
-
-// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value.
-//
-// value: A list of Ncontext_dense shapes; the shapes of data in
-// each context Feature given in context_dense_keys.
-// The number of elements in the Feature corresponding to context_dense_key[j]
-// must always equal context_dense_shapes[j].NumEntries().
-// The shape of context_dense_values[j] will match context_dense_shapes[j].
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["context_dense_shapes"] = value
- }
-}
-
-// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value.
-//
-// value: A list of Nfeature_list_sparse types; the data types
-// of data in each FeatureList given in feature_list_sparse_keys.
-// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["feature_list_sparse_types"] = value
- }
-}
-
-// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value.
-//
-// value: A list of Nfeature_list_dense shapes; the shapes of
-// data in each FeatureList given in feature_list_dense_keys.
-// The shape of each Feature in the FeatureList corresponding to
-// feature_list_dense_key[j] must always equal
-// feature_list_dense_shapes[j].NumEntries().
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr {
- return func(m optionalAttr) {
- m["feature_list_dense_shapes"] = value
- }
-}
-
-// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors.
-//
-// Arguments:
-// serialized: A scalar containing a binary serialized SequenceExample proto.
-// feature_list_dense_missing_assumed_empty: A vector listing the
-// FeatureList keys which may be missing from the SequenceExample. If the
-// associated FeatureList is missing, it is treated as empty. By default,
-// any FeatureList not listed in this vector must exist in the SequenceExample.
-// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars).
-// The keys expected in the Examples' features associated with context_sparse
-// values.
-// context_dense_keys: A list of Ncontext_dense string Tensors (scalars).
-// The keys expected in the SequenceExamples' context features associated with
-// dense values.
-// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
-// (scalars). The keys expected in the FeatureLists associated with sparse
-// values.
-// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars).
-// The keys expected in the SequenceExamples' feature_lists associated
-// with lists of dense values.
-// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty).
-// context_dense_defaults[j] provides default values
-// when the SequenceExample's context map lacks context_dense_key[j].
-// If an empty Tensor is provided for context_dense_defaults[j],
-// then the Feature context_dense_keys[j] is required.
-// The input type is inferred from context_dense_defaults[j], even when it's
-// empty. If context_dense_defaults[j] is not empty, its shape must match
-// context_dense_shapes[j].
-// debug_name: A scalar containing the name of the serialized proto.
-// May contain, for example, table key (descriptive) name for the
-// corresponding serialized proto. This is purely useful for debugging
-// purposes, and the presence of values here has no effect on the output.
-// May also be an empty scalar if no name is available.
-func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ParseSingleSequenceExample",
- Input: []tf.Input{
- serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil {
- scope.UpdateErr("ParseSingleSequenceExample", err)
- return
- }
- return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values
-}
-
-// UnbatchGradAttr is an optional argument to UnbatchGrad.
-type UnbatchGradAttr func(optionalAttr)
-
-// UnbatchGradContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func UnbatchGradContainer(value string) UnbatchGradAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// UnbatchGradSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func UnbatchGradSharedName(value string) UnbatchGradAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Gradient of Unbatch.
-//
-// Acts like Batch but using the given batch_index index of batching things as they
-// become available. This ensures that the gradients are propagated back in the
-// same session which did the forward pass.
-//
-// original_input: The input to the Unbatch operation this is the gradient of.
-// batch_index: The batch_index given to the Unbatch operation this is the gradient
-// of.
-// grad: The downstream gradient.
-// id: The id scalar emitted by Batch.
-// batched_grad: The return value, either an empty tensor or the batched gradient.
-// container: Container to control resource sharing.
-// shared_name: Instances of UnbatchGrad with the same container and shared_name
-// are assumed to possibly belong to the same batch. If left empty, the op name
-// will be used as the shared name.
-func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "UnbatchGrad",
- Input: []tf.Input{
- original_input, batch_index, grad, id,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// DecodeWavAttr is an optional argument to DecodeWav.
-type DecodeWavAttr func(optionalAttr)
-
-// DecodeWavDesiredChannels sets the optional desired_channels attribute to value.
-//
-// value: Number of sample channels wanted.
-// If not specified, defaults to -1
-func DecodeWavDesiredChannels(value int64) DecodeWavAttr {
- return func(m optionalAttr) {
- m["desired_channels"] = value
- }
-}
-
-// DecodeWavDesiredSamples sets the optional desired_samples attribute to value.
-//
-// value: Length of audio requested.
-// If not specified, defaults to -1
-func DecodeWavDesiredSamples(value int64) DecodeWavAttr {
- return func(m optionalAttr) {
- m["desired_samples"] = value
- }
-}
-
-// Decode a 16-bit PCM WAV file to a float tensor.
-//
-// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float.
-//
-// When desired_channels is set, if the input contains fewer channels than this
-// then the last channel will be duplicated to give the requested number, else if
-// the input has more channels than requested then the additional channels will be
-// ignored.
-//
-// If desired_samples is set, then the audio will be cropped or padded with zeroes
-// to the requested length.
-//
-// The first output contains a Tensor with the content of the audio samples. The
-// lowest dimension will be the number of channels, and the second will be the
-// number of samples. For example, a ten-sample-long stereo WAV file should give an
-// output shape of [10, 2].
-//
-// Arguments:
-// contents: The WAV-encoded audio, usually from a file.
-//
-// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header.
-func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DecodeWav",
- Input: []tf.Input{
- contents,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
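A hedged end-to-end sketch of DecodeWav (same imports as the earlier sketches), reading the file with the generated ReadFile wrapper and then running the graph; the WAV path is a hypothetical example:

    s := op.NewScope()
    contents := op.ReadFile(s, op.Const(s.SubScope("path"), "/tmp/speech.wav")) // hypothetical path
    audio, sampleRate := op.DecodeWav(s, contents, op.DecodeWavDesiredChannels(1))
    graph, err := s.Finalize()
    if err != nil {
    	panic(err)
    }
    sess, err := tf.NewSession(graph, nil)
    if err != nil {
    	panic(err)
    }
    defer sess.Close()
    out, err := sess.Run(nil, []tf.Output{audio, sampleRate}, nil)
    if err != nil {
    	panic(err)
    }
    _ = out // out[0]: [samples, channels] float32, out[1]: int32 sample rate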
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 7ceba3903d..87e6107c2d 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -305,6 +305,19 @@ tf_java_test(
],
)
+tf_java_test(
+ name = "ZerosTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/op/core/ZerosTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.op.core.ZerosTest",
+ deps = [
+ ":tensorflow",
+ ":testutil",
+ "@junit",
+ ],
+)
+
filegroup(
name = "processor_test_resources",
srcs = glob([
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
index 7b92be6d38..516655040b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
@@ -17,40 +17,54 @@ package org.tensorflow;
import java.util.HashMap;
import java.util.Map;
+
import org.tensorflow.types.UInt8;
/** Represents the type of elements in a {@link Tensor} as an enum. */
public enum DataType {
/** 32-bit single precision floating point. */
- FLOAT(1),
+ FLOAT(1, 4),
/** 64-bit double precision floating point. */
- DOUBLE(2),
+ DOUBLE(2, 8),
/** 32-bit signed integer. */
- INT32(3),
+ INT32(3, 4),
/** 8-bit unsigned integer. */
- UINT8(4),
+ UINT8(4, 1),
/**
* A sequence of bytes.
*
* <p>TensorFlow uses the STRING type for an arbitrary sequence of bytes.
*/
- STRING(7),
+ STRING(7, -1),
/** 64-bit signed integer. */
- INT64(9),
+ INT64(9, 8),
/** Boolean. */
- BOOL(10);
+ BOOL(10, 1);
private final int value;
+
+ private final int byteSize;
- // The integer value must match the corresponding TF_* value in the TensorFlow C API.
- DataType(int value) {
+ /**
+ * @param value must match the corresponding TF_* value in the TensorFlow C API.
+ * @param byteSize size of an element of this type, in bytes, -1 if unknown
+ */
+ DataType(int value, int byteSize) {
this.value = value;
+ this.byteSize = byteSize;
+ }
+
+ /**
+ * Returns the size of an element of this type, in bytes, or -1 if element size is variable.
+ */
+ public int byteSize() {
+ return byteSize;
}
/** Corresponding value of the TF_DataType enum in the TensorFlow C API. */
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java
index 73324f23e6..a660d25f98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java
@@ -185,11 +185,20 @@ public final class Session implements AutoCloseable {
return this;
}
- /** Makes {@link #run()} return the Tensor referred to by {@code output}. */
+ /**
+ * Makes {@link #run()} return the Tensor referred to by {@code output}.
+ */
public Runner fetch(Output<?> output) {
outputs.add(output);
return this;
}
+
+ /**
+ * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
+ */
+ public Runner fetch(Operand<?> operand) {
+ return fetch(operand.asOutput());
+ }
/**
* Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
@@ -209,6 +218,13 @@ public final class Session implements AutoCloseable {
targets.add(operation);
return this;
}
+
+ /**
+ * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s.
+ */
+ public Runner addTarget(Operand<?> operand) {
+ return addTarget(operand.asOutput().op());
+ }
/**
* (Experimental method): set options (typically for debugging) for this run.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 24a3775db6..8987253768 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -595,20 +595,11 @@ public final class Tensor<T> implements AutoCloseable {
}
private static int elemByteSize(DataType dataType) {
- switch (dataType) {
- case FLOAT:
- case INT32:
- return 4;
- case DOUBLE:
- case INT64:
- return 8;
- case BOOL:
- case UINT8:
- return 1;
- case STRING:
+ int size = dataType.byteSize();
+ if (size < 0) {
throw new IllegalArgumentException("STRING tensors do not have a fixed element size");
}
- throw new IllegalArgumentException("DataType " + dataType + " is not supported yet");
+ return size;
}
private static void throwExceptionIfNotByteOfByteArrays(Object array) {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index de4049f66b..00b6726be3 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -15,11 +15,15 @@ limitations under the License.
package org.tensorflow.op.core;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+import java.nio.charset.Charset;
+
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
@@ -32,25 +36,82 @@ import org.tensorflow.op.annotation.Operator;
/** An operator producing a constant value. */
@Operator
public final class Constant<T> extends PrimitiveOp implements Operand<T> {
+
/**
- * Create a constant from a Java object.
+ * Creates a constant containing a single {@code int} element.
*
- * <p>The argument {@code object} is first converted into a Tensor using {@link
- * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
- * provided. For example:
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return an integer constant
+ */
+ public static Constant<Integer> create(Scope scope, int data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code int} elements.
*
- * <pre>{@code
- * Constant.create(scope, 7); // returns a constant scalar tensor 7
- * }</pre>
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code int} elements.
*
* @param scope is a scope used to add the underlying operation.
- * @param object a Java object representing the constant.
- * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
*/
- public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
- try (Tensor<T> value = Tensor.create(object, type)) {
- return createWithTensor(scope, value);
- }
+ public static Constant<Integer> create(Scope scope, int[][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][][][] data) {
+ return create(scope, data, Integer.class);
}
/**
@@ -64,6 +125,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return an integer constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) {
@@ -73,6 +135,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code float} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a float constant
+ */
+ public static Constant<Float> create(Scope scope, float data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
* Create a {@link DataType#FLOAT} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -83,6 +222,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a float constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) {
@@ -92,6 +232,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code double} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a double constant
+ */
+ public static Constant<Double> create(Scope scope, double data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
* Create a {@link DataType#DOUBLE} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -102,6 +319,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a double constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) {
@@ -111,6 +329,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code long} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a long constant
+ */
+ public static Constant<Long> create(Scope scope, long data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
* Create a {@link DataType#INT64} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -121,6 +416,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a long constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) {
@@ -130,6 +426,174 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code boolean} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a boolean constant
+ */
+ public static Constant<Boolean> create(Scope scope, boolean data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a {@code String} constant using the default UTF-8 encoding.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The string to put into the new constant.
+ * @return a string constant
+ */
+ public static Constant<String> create(Scope scope, String data) {
+ return create(scope, data, UTF_8);
+ }
+
+ /**
+ * Creates a {@code String} constant using a specified encoding.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The string to put into the new constant.
+ * @param charset The encoding used to convert the string to bytes.
+ * @return a string constant
+ */
+ public static Constant<String> create(Scope scope, String data, Charset charset) {
+ try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Creates a constant containing a single {@code String} element, represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
* Create a constant with data from the given buffer.
*
* <p>Creates a Constant with the provided shape of any type where the constant data has been
@@ -141,6 +605,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param type the tensor datatype.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a constant of type `type`
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
@@ -150,6 +615,28 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
}
+ /**
+ * Create a constant from a Java object.
+ *
+ * <p>The argument {@code object} is first converted into a Tensor using {@link
+ * org.tensorflow.Tensor#create(Object)}, so only objects supported by that method can be
+ * provided. For example:
+ *
+ * <pre>{@code
+ * Constant.create(scope, new int[][] {{1, 2}, {3, 4}}, Integer.class); // returns a 2x2 integer matrix
+ * }</pre>
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param object a Java object representing the constant.
+ * @return a constant of type `type`
+ * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ */
+ public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(object, type)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) {
return new Constant<T>(
scope
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
new file mode 100644
index 0000000000..b7c6beb9bc
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.op.core;
+
+import java.nio.ByteBuffer;
+
+import org.tensorflow.DataType;
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * An operator creating a constant initialized with zeros of the shape given by `dims`.
+ *
+ * <p>For example, the following expression
+ * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)}</pre>
+ * is the equivalent of
+ * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))}</pre>
+ *
+ * @param <T> constant type
+ */
+@Operator
+public class Zeros<T> implements Op, Operand<T> {
+
+ /**
+ * Creates a zeroed tensor given its type and shape.
+ *
+ * @param scope is a scope used to add the underlying operation
+ * @param dims a 1-D operand that represents the shape of the output tensor
+ * @param type the output tensor datatype
+ * @return a constant tensor initialized with zeros
+ * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
+ */
+ public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) {
+ Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros"
+ int zeroSize = DataType.fromClass(type).byteSize();
+ if (zeroSize < 0) {
+ throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros");
+ }
+ Constant<T> zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize));
+ return new Zeros<T>(Fill.create(childScope, dims, zero));
+ }
+
+ @Override
+ public Output<T> asOutput() {
+ return fill.asOutput();
+ }
+
+ private final Fill<T> fill;
+
+ private Zeros(Fill<T> fill) {
+ this.fill = fill;
+ }
+}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index ca54214e06..7d3b26de8d 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.op.core;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayOutputStream;
@@ -26,6 +27,7 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -37,6 +39,20 @@ import org.tensorflow.op.Scope;
@RunWith(JUnit4.class)
public class ConstantTest {
private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createInt() {
+ int value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Integer> op = Constant.create(scope, value);
+ try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) {
+ assertEquals(value, result.intValue());
+ }
+ }
+ }
@Test
public void createIntBuffer() {
@@ -47,10 +63,24 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
- Tensor<Integer> result = sess.runner().fetch(op.asOutput())
- .run().get(0).expect(Integer.class);
- int[] actual = new int[ints.length];
- assertArrayEquals(ints, result.copyTo(actual));
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ int[] actual = new int[ints.length];
+ assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual));
+ }
+ }
+ }
+
+ @Test
+ public void createFloat() {
+ float value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Float> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Float.class).floatValue(), 0.0f);
+ }
}
}
@@ -63,9 +93,24 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
- Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
- float[] actual = new float[floats.length];
- assertArrayEquals(floats, result.copyTo(actual), EPSILON);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ float[] actual = new float[floats.length];
+ assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void createDouble() {
+ double value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Double> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Double.class).doubleValue(), 0.0);
+ }
}
}
@@ -78,9 +123,24 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
- Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
- double[] actual = new double[doubles.length];
- assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void createLong() {
+ long value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Long> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Long.class).longValue());
+ }
}
}
@@ -93,15 +153,29 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs));
- Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
- long[] actual = new long[longs.length];
- assertArrayEquals(longs, result.copyTo(actual));
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ long[] actual = new long[longs.length];
+ assertArrayEquals(longs, result.expect(Long.class).copyTo(actual));
+ }
}
}
@Test
- public void createStringBuffer() throws IOException {
+ public void createBoolean() {
+ boolean value = true;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Boolean> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Boolean.class).booleanValue());
+ }
+ }
+ }
+ @Test
+ public void createStringBuffer() throws IOException {
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
long[] shape = {};
@@ -124,8 +198,9 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content));
- Tensor<String> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class);
- assertArrayEquals(data, result.bytesValue());
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertArrayEquals(data, result.expect(String.class).bytesValue());
+ }
}
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
new file mode 100644
index 0000000000..cf3910b594
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
@@ -0,0 +1,165 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import java.util.List;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.op.Scope;
+import org.tensorflow.types.UInt8;
+
+@RunWith(JUnit4.class)
+public class ZerosTest {
+ private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createIntZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createFloatZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0f, actual[i][j], EPSILON);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createDoubleZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0, actual[i][j], EPSILON);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createLongZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0L, actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createBooleanZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertFalse(actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createUInt8Zeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void cannotCreateStringZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros.create(scope, Constant.create(scope, shape), String.class);
+ }
+ }
+
+ @Test
+ public void operationsComposingZerosAreCorrectlyNamed() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class);
+ List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
+ }
+ }
+}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 83b82aa0cc..456f007348 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -835,8 +835,10 @@ py_library(
deps = [
":c_api_util",
":control_flow_util",
+ ":cpp_shape_inference_proto_py",
":device",
":dtypes",
+ ":error_interpolation",
":op_def_registry",
":platform",
":registry",
@@ -3216,14 +3218,18 @@ py_library(
"training/checkpointable/**/*.py",
# The following targets have their own build rules (same name as the
# file):
+ "training/checkpoint_management.py",
"training/saveable_object.py",
+ "training/saver.py",
"training/training_util.py",
],
),
srcs_version = "PY2AND3",
deps = [
+ "saver",
":array_ops",
":array_ops_gen",
+ ":checkpoint_management",
":checkpoint_ops_gen",
":client",
":control_flow_ops",
@@ -3235,24 +3241,20 @@ py_library(
":framework_ops",
":gradients",
":init_ops",
- ":distribute",
":io_ops",
- ":io_ops_gen",
":layers_base",
- ":lib",
":lookup_ops",
":math_ops",
":platform",
- ":protos_all_py",
":pywrap_tensorflow",
":random_ops",
":resource_variable_ops",
":resources",
- ":saveable_object",
":sdca_ops",
+ ":session",
":sparse_ops",
+ ":sparse_tensor",
":state_ops",
- ":string_ops",
":summary",
":training_ops_gen",
":training_util",
@@ -3262,6 +3264,7 @@ py_library(
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
# `layers` dependency only exists due to the use of a small utility.
@@ -3279,6 +3282,52 @@ py_library(
)
py_library(
+ name = "checkpoint_management",
+ srcs = ["training/checkpoint_management.py"],
+ deps = [
+ ":errors",
+ ":lib",
+ ":platform",
+ ":protos_all_py",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_library(
+ name = "saver",
+ srcs = ["training/saver.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":checkpoint_management",
+ ":constant_op",
+ ":control_flow_ops",
+ ":device",
+ ":errors",
+ ":framework",
+ ":framework_ops",
+ ":io_ops",
+ ":io_ops_gen",
+ ":platform",
+ ":pywrap_tensorflow",
+ ":resource_variable_ops",
+ ":saveable_object",
+ ":session",
+ ":state_ops",
+ ":string_ops",
+ ":training_util",
+ ":util",
+ ":variables",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:base",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "device_util",
srcs = ["training/device_util.py"],
srcs_version = "PY2AND3",
@@ -4388,6 +4437,42 @@ cuda_py_test(
tags = ["multi_gpu"],
)
+cuda_py_test(
+ name = "checkpoint_management_test",
+ size = "small",
+ srcs = [
+ "training/checkpoint_management_test.py",
+ ],
+ additional_deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":data_flow_ops",
+ ":errors",
+ ":gradients",
+ ":math_ops",
+ ":nn_grad",
+ ":nn_ops",
+ ":saver_test_utils",
+ ":partitioned_variables",
+ ":platform",
+ ":platform_test",
+ ":pywrap_tensorflow",
+ ":random_ops",
+ ":resource_variable_ops",
+ ":sparse_ops",
+ ":summary",
+ ":training",
+ ":util",
+ ":variable_scope",
+ ":variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
py_test(
name = "saver_large_variable_test",
size = "medium",
@@ -4454,6 +4539,7 @@ tf_py_test(
srcs = ["training/supervisor_test.py"],
additional_deps = [
":array_ops",
+ ":checkpoint_management",
":client_testlib",
":errors",
":framework",
@@ -4461,6 +4547,7 @@ tf_py_test(
":io_ops",
":parsing_ops",
":platform",
+ ":saver",
":summary",
":training",
":variables",
@@ -4574,10 +4661,13 @@ py_test(
tags = ["notsan"], # b/67945581
deps = [
":array_ops",
+ ":checkpoint_management",
":client_testlib",
":control_flow_ops",
":errors",
":framework_for_generated_wrappers",
+ ":resource_variable_ops",
+ ":saver",
":session",
":state_ops",
":summary",
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 247ea7349d..af47ff69c9 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 1)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 6)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index b66b87ce6c..23c98247bf 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -329,6 +329,8 @@ cuda_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -350,6 +352,8 @@ cuda_py_test(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
"//tensorflow/python/compat:compat",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
],
grpc_enabled = True,
)
@@ -381,3 +385,22 @@ tf_py_test(
"no_windows",
],
)
+
+cuda_py_test(
+ name = "optional_ops_test",
+ size = "small",
+ srcs = ["optional_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index 25269dc810..4f7fd3566e 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class FilesystemCacheDatasetTest(test.TestCase):
+class FileCacheDatasetTest(test.TestCase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index b434fa7334..352424514e 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import os
import warnings
@@ -46,7 +47,9 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
@@ -788,5 +791,98 @@ class IteratorTest(test.TestCase):
val += 1
+class IteratorCheckpointingTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestoreOneShotIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(
+ math_ops.square).batch(2)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ with self.test_session() as sess:
+ self.assertAllEqual([1, 4], get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestoreMultipleIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+ dataset = dataset.map(math_ops.square).batch(2)
+ iterator_1 = dataset.make_one_shot_iterator()
+ get_next_1 = iterator_1.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_1.get_next())
+ iterator_2 = dataset.make_one_shot_iterator()
+ get_next_2 = iterator_2.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_2.get_next())
+ dataset_2 = dataset_ops.Dataset.range(10)
+ iterator_3 = dataset_2.make_one_shot_iterator()
+ get_next_3 = iterator_3.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_3.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(
+ iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
+ with self.test_session() as sess:
+ self.assertAllEqual([1, 4], get_next_1())
+ self.assertAllEqual(0, get_next_3())
+ self.assertAllEqual(1, get_next_3())
+ self.assertAllEqual(2, get_next_3())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual([9, 16], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ self.assertAllEqual([9, 16], get_next_1())
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testRestoreExhaustedIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.range(3)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ with self.test_session() as sess:
+ self.assertAllEqual(0, get_next())
+ self.assertAllEqual(1, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual(2, get_next())
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ self.assertAllEqual(2, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ checkpoint.restore(save_path).run_restore_ops(sess)
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
+
+ def testRestoreInReconstructedIteratorInitializable(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.range(10)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ for i in range(5):
+ with self.test_session() as sess:
+ checkpoint.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory)).initialize_or_restore(sess)
+ for j in range(2):
+ self.assertEqual(i * 2 + j, sess.run(get_next))
+ checkpoint.save(file_prefix=checkpoint_prefix)
+
+
if __name__ == "__main__":
test.main()
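
The tests above exercise the new iterator checkpointing support end to end. A condensed sketch of the same pattern outside a test harness, using the internal `Checkpoint` utility the tests import (the save path is only an example):

    import tensorflow as tf
    from tensorflow.python.training.checkpointable import util as checkpointable_utils

    dataset = tf.data.Dataset.range(5)
    iterator = dataset.make_one_shot_iterator()
    get_next = iterator.get_next()
    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)

    with tf.Session() as sess:
        print(sess.run(get_next))                       # 0
        save_path = checkpoint.save("/tmp/iter_ckpt")   # records the position after 0
        print(sess.run(get_next))                       # 1
        checkpoint.restore(save_path).run_restore_ops(sess)
        print(sess.run(get_next))                       # 1 again, replayed after the restore
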
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index f7d7d085c9..579096f880 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -123,13 +123,11 @@ class ListFilesDatasetOpTest(test.TestCase):
with self.test_session() as sess:
itr = dataset.make_initializable_iterator()
- next_element = itr.get_next()
- sess.run(
- itr.initializer,
- feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError, 'No files matched pattern: '):
+ sess.run(
+ itr.initializer,
+ feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
def testSimpleDirectoryInitializer(self):
filenames = ['a', 'b', 'c']
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
new file mode 100644
index 0000000000..a32527af8d
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -0,0 +1,186 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Optional data type wrapper."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class OptionalTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFromValue(self):
+ opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
+ self.assertEqual(dtypes.float32, opt.output_types)
+ self.assertEqual([], opt.output_shapes)
+ self.assertEqual(ops.Tensor, opt.output_classes)
+ self.assertTrue(self.evaluate(opt.has_value()))
+ self.assertEqual(37.0, self.evaluate(opt.get_value()))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFromStructuredValue(self):
+ opt = optional_ops.Optional.from_value({
+ "a": constant_op.constant(37.0),
+ "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
+ })
+ self.assertEqual({
+ "a": dtypes.float32,
+ "b": (dtypes.string, dtypes.string)
+ }, opt.output_types)
+ self.assertEqual({"a": [], "b": ([1], [])}, opt.output_shapes)
+ self.assertEqual({
+ "a": ops.Tensor,
+ "b": (ops.Tensor, ops.Tensor)
+ }, opt.output_classes)
+ self.assertTrue(self.evaluate(opt.has_value()))
+ self.assertEqual({
+ "a": 37.0,
+ "b": ([b"Foo"], b"Bar")
+ }, self.evaluate(opt.get_value()))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFromSparseTensor(self):
+ st_0 = sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0], dtype=np.int64),
+ dense_shape=np.array([1]))
+ st_1 = sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0], [1, 1]]),
+ values=np.array([-1., 1.], dtype=np.float32),
+ dense_shape=np.array([2, 2]))
+ opt = optional_ops.Optional.from_value((st_0, st_1))
+ self.assertEqual((dtypes.int64, dtypes.float32), opt.output_types)
+ self.assertEqual(([1], [2, 2]), opt.output_shapes)
+ self.assertEqual((sparse_tensor.SparseTensor, sparse_tensor.SparseTensor),
+ opt.output_classes)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFromNone(self):
+ opt = optional_ops.Optional.none_from_structure(tensor_shape.scalar(),
+ dtypes.float32, ops.Tensor)
+ self.assertEqual(dtypes.float32, opt.output_types)
+ self.assertEqual([], opt.output_shapes)
+ self.assertEqual(ops.Tensor, opt.output_classes)
+ self.assertFalse(self.evaluate(opt.has_value()))
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(opt.get_value())
+
+ def testStructureMismatchError(self):
+ tuple_output_shapes = (tensor_shape.scalar(), tensor_shape.scalar())
+ tuple_output_types = (dtypes.float32, dtypes.float32)
+ tuple_output_classes = (ops.Tensor, ops.Tensor)
+
+ dict_output_shapes = {
+ "a": tensor_shape.scalar(),
+ "b": tensor_shape.scalar()
+ }
+ dict_output_types = {"a": dtypes.float32, "b": dtypes.float32}
+ dict_output_classes = {"a": ops.Tensor, "b": ops.Tensor}
+
+ with self.assertRaises(TypeError):
+ optional_ops.Optional.none_from_structure(
+ tuple_output_shapes, tuple_output_types, dict_output_classes)
+
+ with self.assertRaises(TypeError):
+ optional_ops.Optional.none_from_structure(
+ tuple_output_shapes, dict_output_types, tuple_output_classes)
+
+ with self.assertRaises(TypeError):
+ optional_ops.Optional.none_from_structure(
+ dict_output_shapes, tuple_output_types, tuple_output_classes)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testCopyToGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ with ops.device("/cpu:0"):
+ optional_with_value = optional_ops.Optional.from_value(
+ (constant_op.constant(37.0), constant_op.constant("Foo"),
+ constant_op.constant(42)))
+ optional_none = optional_ops.Optional.none_from_structure(
+ tensor_shape.scalar(), dtypes.float32, ops.Tensor)
+
+ with ops.device("/gpu:0"):
+ gpu_optional_with_value = optional_ops._OptionalImpl(
+ array_ops.identity(optional_with_value._variant_tensor),
+ optional_with_value.output_shapes, optional_with_value.output_types,
+ optional_with_value.output_classes)
+ gpu_optional_none = optional_ops._OptionalImpl(
+ array_ops.identity(optional_none._variant_tensor),
+ optional_none.output_shapes, optional_none.output_types,
+ optional_none.output_classes)
+
+ gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
+ gpu_optional_with_value_values = gpu_optional_with_value.get_value()
+
+ gpu_optional_none_has_value = gpu_optional_none.has_value()
+
+ self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
+ self.assertEqual((37.0, b"Foo", 42),
+ self.evaluate(gpu_optional_with_value_values))
+ self.assertFalse(self.evaluate(gpu_optional_none_has_value))
+
+ def testIteratorGetNextAsOptional(self):
+ ds = dataset_ops.Dataset.range(3)
+ iterator = ds.make_initializable_iterator()
+ next_elem = iterator_ops.get_next_as_optional(iterator)
+ self.assertTrue(isinstance(next_elem, optional_ops.Optional))
+ self.assertEqual(ds.output_types, next_elem.output_types)
+ self.assertEqual(ds.output_shapes, next_elem.output_shapes)
+ self.assertEqual(ds.output_classes, next_elem.output_classes)
+ elem_has_value_t = next_elem.has_value()
+ elem_value_t = next_elem.get_value()
+ with self.test_session() as sess:
+ # Before initializing the iterator, evaluating the optional fails with
+ # a FailedPreconditionError.
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_has_value_t)
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_value_t)
+
+ # For each element of the dataset, assert that the optional evaluates to
+ # the expected value.
+ sess.run(iterator.initializer)
+ for i in range(3):
+ elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
+ self.assertTrue(elem_has_value)
+ self.assertEqual(i, elem_value)
+
+ # After exhausting the iterator, `next_elem.has_value()` will evaluate to
+ # false, and attempting to get the value will fail.
+ for _ in range(2):
+ self.assertFalse(sess.run(elem_has_value_t))
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(elem_value_t)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index f15eb6310f..50ba5f403e 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -11,6 +11,7 @@ py_library(
deps = [
":iterator_ops",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
@@ -19,6 +20,7 @@ py_library(
"//tensorflow/python:random_seed",
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
@@ -50,14 +52,33 @@ py_library(
srcs = ["iterator_ops.py"],
srcs_version = "PY2AND3",
deps = [
+ ":optional_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:saver",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:base",
+ ],
+)
+
+py_library(
+ name = "optional_ops",
+ srcs = ["optional_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
],
)
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 88de4b588c..6cda2a77cc 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -39,10 +39,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -644,17 +646,34 @@ class Dataset(object):
Returns:
Dataset: A `Dataset` of strings corresponding to file names.
"""
- if shuffle is None:
- shuffle = True
- matching_files = gen_io_ops.matching_files(file_pattern)
- dataset = Dataset.from_tensor_slices(matching_files)
- if shuffle:
- # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
- # list of files might be empty.
- buffer_size = math_ops.maximum(
- array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
- dataset = dataset.shuffle(buffer_size, seed=seed)
- return dataset
+ with ops.name_scope("list_files"):
+ if shuffle is None:
+ shuffle = True
+ file_pattern = ops.convert_to_tensor(
+ file_pattern, dtype=dtypes.string, name="file_pattern")
+ matching_files = gen_io_ops.matching_files(file_pattern)
+
+ # Raise an exception if `file_pattern` does not match any files.
+ condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
+ name="match_not_empty")
+
+ message = math_ops.add(
+ "No files matched pattern: ",
+ string_ops.reduce_join(file_pattern, separator=", "), name="message")
+
+ assert_not_empty = control_flow_ops.Assert(
+ condition, [message], summarize=1, name="assert_not_empty")
+ with ops.control_dependencies([assert_not_empty]):
+ matching_files = array_ops.identity(matching_files)
+
+ dataset = Dataset.from_tensor_slices(matching_files)
+ if shuffle:
+ # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
+ # list of files might be empty.
+ buffer_size = math_ops.maximum(
+ array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
+ dataset = dataset.shuffle(buffer_size, seed=seed)
+ return dataset
def repeat(self, count=None):
"""Repeats this dataset `count` times.
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 494df178df..83c541c2f7 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -21,6 +21,7 @@ import threading
import warnings
from tensorflow.python.compat import compat
+from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
@@ -30,6 +31,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util.tf_export import tf_export
@@ -65,7 +68,7 @@ def _device_stack_is_empty():
@tf_export("data.Iterator")
-class Iterator(object):
+class Iterator(checkpointable.CheckpointableBase):
"""Represents the state of iterating through a `Dataset`."""
def __init__(self, iterator_resource, initializer, output_types,
@@ -464,6 +467,13 @@ class Iterator(object):
"""
return self._output_types
+ def _gather_saveables_for_checkpoint(self):
+
+ def _saveable_factory(name):
+ return _IteratorSaveable(self._iterator_resource, name)
+
+ return {"ITERATOR": _saveable_factory}
+
_uid_counter = 0
_uid_lock = threading.Lock()
@@ -477,7 +487,7 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
-class EagerIterator(object):
+class EagerIterator(checkpointable.CheckpointableBase):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
def __init__(self, dataset):
@@ -610,3 +620,56 @@ class EagerIterator(object):
"""
del name
return self._next_internal()
+
+ def _gather_saveables_for_checkpoint(self):
+
+ def _saveable_factory(name):
+ return _IteratorSaveable(self._resource, name)
+
+ return {"ITERATOR": _saveable_factory}
+
+
+# TODO(b/71645805): Expose checkpointable stateful objects from dataset
+# attributes(potential).
+class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
+ """SaveableObject for saving/restoring iterator state."""
+
+ def __init__(self, iterator_resource, name):
+ serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
+ specs = [
+ BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
+ ]
+ # pylint: disable=protected-access
+ super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ with ops.colocate_with(self.op):
+ return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+def get_next_as_optional(iterator):
+ """Returns an `Optional` that contains the next value from the iterator.
+
+ If `iterator` has reached the end of the sequence, the returned `Optional`
+ will have no value.
+
+ Args:
+ iterator: A `tf.data.Iterator` object.
+
+ Returns:
+ An `Optional` object representing the next value from the iterator (if it
+ has one) or no value.
+ """
+ # pylint: disable=protected-access
+ return optional_ops._OptionalImpl(
+ gen_dataset_ops.iterator_get_next_as_optional(
+ iterator._iterator_resource,
+ output_types=nest.flatten(
+ sparse.as_dense_types(iterator.output_types,
+ iterator.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(iterator.output_shapes,
+ iterator.output_classes))),
+ output_shapes=iterator.output_shapes,
+ output_types=iterator.output_types,
+ output_classes=iterator.output_classes)
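
A minimal sketch of how `get_next_as_optional` is consumed, using the internal module path from this file (the `tf.contrib.data` alias referenced in `optional_ops.py` is assumed):

    import tensorflow as tf
    from tensorflow.python.data.ops import iterator_ops

    dataset = tf.data.Dataset.range(3)
    iterator = dataset.make_initializable_iterator()
    next_opt = iterator_ops.get_next_as_optional(iterator)
    has_value_t = next_opt.has_value()
    value_t = next_opt.get_value()

    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(3):
            # Fetch both tensors in one run() so the underlying op executes once.
            has_value, value = sess.run([has_value_t, value_t])
            print(has_value, value)      # True 0, True 1, True 2
        print(sess.run(has_value_t))     # False once the iterator is exhausted
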
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
new file mode 100644
index 0000000000..1d3007ef76
--- /dev/null
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -0,0 +1,209 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An Optional type for representing potentially missing values."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class Optional(object):
+ """Wraps a nested structure of tensors that may/may not be present at runtime.
+
+ An `Optional` can represent the result of an operation that may fail, as a
+ value rather than by raising an exception and halting execution. For example,
+ @{tf.contrib.data.get_next_as_optional} returns an `Optional` that either
+ contains the next value from a @{tf.data.Iterator} if one exists, or a "none"
+ value that indicates the end of the sequence has been reached.
+ """
+
+ @abc.abstractmethod
+ def has_value(self, name=None):
+ """Returns a tensor that evaluates to `True` if this optional has a value.
+
+ Args:
+ name: (Optional.) A name for the created operation.
+
+ Returns:
+ A scalar `tf.Tensor` of type `tf.bool`.
+ """
+ raise NotImplementedError("Optional.has_value()")
+
+ @abc.abstractmethod
+ def get_value(self, name=None):
+ """Returns a nested structure of values wrapped by this optional.
+
+ If this optional does not have a value (i.e. `self.has_value()` evaluates
+ to `False`), this operation will raise @{tf.errors.InvalidArgumentError}
+ at runtime.
+
+ Args:
+ name: (Optional.) A name for the created operation.
+
+ Returns:
+ A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.
+ """
+ raise NotImplementedError("Optional.get_value()")
+
+ @abc.abstractproperty
+ def output_classes(self):
+ """Returns the class of each component of this optional.
+
+ The expected values are `tf.Tensor` and `tf.SparseTensor`.
+
+ Returns:
+ A nested structure of Python `type` objects corresponding to each
+ component of this optional.
+ """
+ raise NotImplementedError("Optional.output_classes")
+
+ @abc.abstractproperty
+ def output_shapes(self):
+ """Returns the shape of each component of this optional.
+
+ Returns:
+ A nested structure of `tf.TensorShape` objects corresponding to each
+ component of this optional.
+ """
+ raise NotImplementedError("Optional.output_shapes")
+
+ @abc.abstractproperty
+ def output_types(self):
+ """Returns the type of each component of this optional.
+
+ Returns:
+ A nested structure of `tf.DType` objects corresponding to each component
+ of this optional.
+ """
+ raise NotImplementedError("Optional.output_types")
+
+ @staticmethod
+ def from_value(value):
+ """Returns an `Optional` that wraps the given value.
+
+ Args:
+ value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.
+
+ Returns:
+ An `Optional` that wraps `value`.
+ """
+ # TODO(b/110122868): Consolidate this destructuring logic with the
+ # similar code in `Dataset.from_tensors()`.
+ with ops.name_scope("optional") as scope:
+ with ops.name_scope("value"):
+ value = nest.pack_sequence_as(value, [
+ sparse_tensor_lib.SparseTensor.from_value(t)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(value))
+ ])
+
+ encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
+ output_classes = sparse.get_classes(value)
+ output_shapes = nest.pack_sequence_as(
+ value, [t.get_shape() for t in nest.flatten(value)])
+ output_types = nest.pack_sequence_as(
+ value, [t.dtype for t in nest.flatten(value)])
+
+ return _OptionalImpl(
+ gen_dataset_ops.optional_from_value(encoded_value, name=scope),
+ output_shapes, output_types, output_classes)
+
+ @staticmethod
+ def none_from_structure(output_shapes, output_types, output_classes):
+ """Returns an `Optional` that has no value.
+
+ NOTE: This method takes arguments that define the structure of the value
+ that would be contained in the returned `Optional` if it had a value.
+
+ Args:
+ output_shapes: A nested structure of `tf.TensorShape` objects
+ corresponding to each component of this optional.
+ output_types: A nested structure of `tf.DType` objects corresponding to
+ each component of this optional.
+ output_classes: A nested structure of Python `type` objects corresponding
+ to each component of this optional.
+
+ Returns:
+ An `Optional` that has no value.
+ """
+ return _OptionalImpl(gen_dataset_ops.optional_none(), output_shapes,
+ output_types, output_classes)
+
+
+class _OptionalImpl(Optional):
+ """Concrete implementation of `tf.contrib.data.Optional`.
+
+ NOTE(mrry): This implementation is kept private, to avoid defining
+ `Optional.__init__()` in the public API.
+ """
+
+ def __init__(self, variant_tensor, output_shapes, output_types,
+ output_classes):
+ # TODO(b/110122868): Consolidate the structure validation logic with the
+ # similar logic in `Iterator.from_structure()` and
+ # `Dataset.from_generator()`.
+ output_types = nest.map_structure(dtypes.as_dtype, output_types)
+ output_shapes = nest.map_structure_up_to(
+ output_types, tensor_shape.as_shape, output_shapes)
+ nest.assert_same_structure(output_types, output_shapes)
+ nest.assert_same_structure(output_types, output_classes)
+ self._variant_tensor = variant_tensor
+ self._output_shapes = output_shapes
+ self._output_types = output_types
+ self._output_classes = output_classes
+
+ def has_value(self, name=None):
+ return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
+
+ def get_value(self, name=None):
+ # TODO(b/110122868): Consolidate the restructuring logic with similar logic
+ # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
+ with ops.name_scope(name, "OptionalGetValue",
+ [self._variant_tensor]) as scope:
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(
+ self._output_types,
+ gen_dataset_ops.optional_get_value(
+ self._variant_tensor,
+ name=scope,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._output_types,
+ self._output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes,
+ self._output_classes)))),
+ self._output_types, self._output_shapes, self._output_classes)
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
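
A short usage sketch of the `Optional` API defined above, run eagerly for brevity (eager execution support in the installed TensorFlow build is assumed):

    import tensorflow as tf
    from tensorflow.python.data.ops import optional_ops

    tf.enable_eager_execution()

    present = optional_ops.Optional.from_value(tf.constant(37.0))
    absent = optional_ops.Optional.none_from_structure(
        output_shapes=tf.TensorShape([]),
        output_types=tf.float32,
        output_classes=tf.Tensor)

    print(bool(present.has_value()))    # True
    print(float(present.get_value()))   # 37.0
    print(bool(absent.has_value()))     # False
    # absent.get_value() would raise tf.errors.InvalidArgumentError at runtime.
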
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 04c50dbafc..dab1ed43ca 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -34,13 +34,13 @@ class _TaskType(object):
EVALUATOR = "evaluator"
-_coordinator_context = threading.local()
+_worker_context = threading.local()
-def get_current_coordinator_context():
- """Returns the current coordinator context."""
+def get_current_worker_context():
+ """Returns the current task context."""
try:
- return _coordinator_context.current
+ return _worker_context.current
except AttributeError:
return None
@@ -86,13 +86,13 @@ def _get_num_workers(cluster_spec):
cluster_spec.as_dict().get(_TaskType.CHIEF, []))
-class _CoordinatorContext(object):
- """The coordinator context class.
+class _WorkerContext(object):
+ """The worker context class.
This context object provides configuration information for each task. One
- context manager with a coordinator context object will be created per
- invocation to the `worker_fn` where `get_current_coordinator_context` can be
- called to access the coordinator context object.
+ context manager with a worker context object will be created per
+ invocation to the `worker_fn` where `get_current_worker_context` can be called
+ to access the worker context object.
"""
def __init__(self,
@@ -102,7 +102,7 @@ class _CoordinatorContext(object):
between_graph=False,
rpc_layer="grpc",
worker_barrier=None):
- """Initialize the coordinator context object.
+ """Initialize the worker context object.
Args:
cluster_spec: a ClusterSpec object. It can be empty or None in the local
@@ -139,15 +139,15 @@ class _CoordinatorContext(object):
self._is_chief_node = self._is_chief()
def __enter__(self):
- old_context = get_current_coordinator_context()
+ old_context = get_current_worker_context()
if old_context:
raise ValueError(
"You cannot run distribute coordinator in a `worker_fn`.")
- _coordinator_context.current = self
+ _worker_context.current = self
def __exit__(self, unused_exception_type, unused_exception_value,
unused_traceback):
- _coordinator_context.current = None
+ _worker_context.current = None
def _get_master_target(self):
"""Return the master target for a task."""
@@ -195,7 +195,7 @@ class _CoordinatorContext(object):
"""
if not self._worker_barrier:
raise ValueError(
- "`worker_barrier is not set in the coordinator context.`")
+ "`worker_barrier is not set in the worker context.`")
self._worker_barrier.wait()
@property
@@ -236,8 +236,8 @@ class _CoordinatorContext(object):
def _run(worker_fn, cluster_spec, task_type, task_id, between_graph, rpc_layer,
worker_barrier):
- with _CoordinatorContext(cluster_spec, task_type, task_id, between_graph,
- rpc_layer, worker_barrier):
+ with _WorkerContext(cluster_spec, task_type, task_id, between_graph,
+ rpc_layer, worker_barrier):
worker_fn()
@@ -266,13 +266,13 @@ def run_distribute_coordinator(worker_fn,
this `worker_fn`.
The `worker_fn` defines the training logic and is called under its own
- coordinator context which can be accessed to via
- `get_current_coordinator_context`. A coordinator context provides access to
- configurations for each task, e.g. the task_type, task_id, master target and
- so on. Since `worker_fn` will be called in a thread and possibly multiple
- times, caller should be careful when it accesses global data. For example, it
- is unsafe to define flags in a `worker_fn` or to define different environment
- variables for different `worker_fn`s.
+  worker context which can be accessed via `get_current_worker_context`. A
+ worker context provides access to configurations for each task, e.g. the
+ task_type, task_id, master target and so on. Since `worker_fn` will be called
+ in a thread and possibly multiple times, caller should be careful when it
+ accesses global data. For example, it is unsafe to define flags in a
+ `worker_fn` or to define different environment variables for different
+ `worker_fn`s.
The `worker_fn` for the between-graph replication is defined as if there is
only one worker corresponding to the `worker_fn` and possibly ps jobs. It
@@ -287,7 +287,7 @@ def run_distribute_coordinator(worker_fn,
high-level APIs, to change a program to use this coordinator, wrap everything
in the program after global data definitions such as command-line flag
definition into the `worker_fn` and get task-specific configurations from
- the coordinator context.
+ the worker context.
The `cluster_spec` can either be passed as an argument or parsed from the
"TF_CONFIG" environment variable. Example of a TF_CONFIG:
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 82fd823352..d7ffeb56a5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -67,7 +67,7 @@ class DistributeCoordinatorTest(test.TestCase):
def setUp(self):
self._result_correct = 0
self._lock = threading.Lock()
- self._task_context = {}
+ self._worker_context = {}
@contextlib.contextmanager
def _test_session(self, target):
@@ -77,7 +77,7 @@ class DistributeCoordinatorTest(test.TestCase):
yield sess
def _in_graph_worker_fn(self):
- context = distribute_coordinator.get_current_coordinator_context()
+ context = distribute_coordinator.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
xs = []
@@ -107,7 +107,7 @@ class DistributeCoordinatorTest(test.TestCase):
self.assertEqual(self._result_correct, 1)
def _between_graph_worker_fn(self):
- context = distribute_coordinator.get_current_coordinator_context()
+ context = distribute_coordinator.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
with ops.device("/job:ps/task:0"):
@@ -153,113 +153,113 @@ class DistributeCoordinatorTest(test.TestCase):
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
- def _dump_task_context(self):
- """Dumps the propoerties of each coordinator context.
+ def _dump_worker_context(self):
+    """Dumps the properties of each worker context.
It dumps the context properties to a dict mapping from task_type to a list
of tuples of master_target, num_workers, is_chief and distributed_mode, where
the list is indexed by the task_id.
"""
- context = distribute_coordinator.get_current_coordinator_context()
+ context = distribute_coordinator.get_current_worker_context()
self.assertTrue(context is not None)
task_type = str(context.task_type)
task_id = context.task_id or 0
with self._lock:
- if task_type not in self._task_context:
- self._task_context[task_type] = []
- while len(self._task_context[task_type]) <= task_id:
- self._task_context[task_type].append(None)
- self._task_context[task_type][task_id] = (context.master_target,
- context.num_workers,
- context.is_chief,
- context.distributed_mode)
+ if task_type not in self._worker_context:
+ self._worker_context[task_type] = []
+ while len(self._worker_context[task_type]) <= task_id:
+ self._worker_context[task_type].append(None)
+ self._worker_context[task_type][task_id] = (context.master_target,
+ context.num_workers,
+ context.is_chief,
+ context.distributed_mode)
def testBetweenGraphContext(self):
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context,
+ self._dump_worker_context,
cluster_spec=self._cluster_spec,
between_graph=True)
# There is only one type of task and there are three such tasks.
- self.assertEqual(len(self._task_context), 1)
- self.assertTrue(WORKER in self._task_context)
- self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue(WORKER in self._worker_context)
+ self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
- self._task_context[WORKER][0],
+ self._worker_context[WORKER][0],
(_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
self.assertEqual(
- self._task_context[WORKER][1],
+ self._worker_context[WORKER][1],
(_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
self.assertEqual(
- self._task_context[WORKER][2],
+ self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
def testInGraphContext(self):
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context,
+ self._dump_worker_context,
cluster_spec=self._cluster_spec,
between_graph=False)
# There is only a "None" task in the dumped task context.
- self.assertEqual(len(self._task_context), 1)
- self.assertTrue("None" in self._task_context)
- self.assertEqual(len(self._task_context["None"]), 1)
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue("None" in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
- self._task_context["None"][0],
+ self._worker_context["None"][0],
(_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
def testLocalContext(self):
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context, cluster_spec=None, between_graph=True)
+ self._dump_worker_context, cluster_spec=None, between_graph=True)
# There is only a "None" task.
- self.assertEqual(len(self._task_context), 1)
- self.assertTrue("None" in self._task_context)
- self.assertEqual(len(self._task_context["None"]), 1)
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue("None" in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._task_context["None"][0], ("local", 0, True, False))
+ self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
def testBetweenGraphContextWithChief(self):
# Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
cluster_spec = copy.deepcopy(self._cluster_spec)
cluster_spec[CHIEF] = ["fake_chief"]
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context,
+ self._dump_worker_context,
cluster_spec=cluster_spec,
between_graph=True,
rpc_layer="grpc")
# There is one CHIEF and three workers.
- self.assertEqual(len(self._task_context), 2)
- self.assertTrue(CHIEF in self._task_context)
- self.assertTrue(WORKER in self._task_context)
- self.assertEqual(len(self._task_context[CHIEF]), 1)
- self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+ self.assertEqual(len(self._worker_context), 2)
+ self.assertTrue(CHIEF in self._worker_context)
+ self.assertTrue(WORKER in self._worker_context)
+ self.assertEqual(len(self._worker_context[CHIEF]), 1)
+ self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._task_context[CHIEF][0],
+ self.assertEqual(self._worker_context[CHIEF][0],
("grpc://fake_chief", 4, True, True))
- self.assertEqual(self._task_context[WORKER][0],
+ self.assertEqual(self._worker_context[WORKER][0],
("grpc://" + _bytes_to_str(self._workers[0].target),
NUM_WORKERS + 1, False, True))
- self.assertEqual(self._task_context[WORKER][1],
+ self.assertEqual(self._worker_context[WORKER][1],
("grpc://" + _bytes_to_str(self._workers[1].target),
NUM_WORKERS + 1, False, True))
- self.assertEqual(self._task_context[WORKER][2],
+ self.assertEqual(self._worker_context[WORKER][2],
("grpc://" + _bytes_to_str(self._workers[2].target),
NUM_WORKERS + 1, False, True))
@@ -268,22 +268,24 @@ class DistributeCoordinatorTest(test.TestCase):
cluster_spec = copy.deepcopy(self._cluster_spec)
cluster_spec[EVALUATOR] = ["fake_evaluator"]
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context, cluster_spec=cluster_spec, between_graph=False)
+ self._dump_worker_context,
+ cluster_spec=cluster_spec,
+ between_graph=False)
# There is one "None" task and one EVALUATOR task.
- self.assertEqual(len(self._task_context), 2)
- self.assertTrue("None" in self._task_context)
- self.assertTrue(EVALUATOR in self._task_context)
- self.assertEqual(len(self._task_context["None"]), 1)
- self.assertEqual(len(self._task_context[EVALUATOR]), 1)
+ self.assertEqual(len(self._worker_context), 2)
+ self.assertTrue("None" in self._worker_context)
+ self.assertTrue(EVALUATOR in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
+ self.assertEqual(len(self._worker_context[EVALUATOR]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._task_context["None"][0],
+ self.assertEqual(self._worker_context["None"][0],
(_bytes_to_str(self._workers[0].target), 3, True, True))
- self.assertEqual(self._task_context[EVALUATOR][0],
+ self.assertEqual(self._worker_context[EVALUATOR][0],
("fake_evaluator", 3, False, True))
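The docstring changes above and the tests that exercise them describe a `worker_fn` that pulls its per-task configuration from the current worker context. A minimal sketch, not part of the patch, assuming TF 1.x with this change applied and using the internal module path shown in this diff:

from tensorflow.python.distribute import distribute_coordinator

def worker_fn():
  # Task-specific configuration is read from the current worker context.
  context = distribute_coordinator.get_current_worker_context()
  assert context is not None
  print(context.task_type, context.task_id,
        context.is_chief, context.master_target)

# Local mode: no cluster spec, a single "None" task, master target "local".
distribute_coordinator.run_distribute_coordinator(
    worker_fn, cluster_spec=None, between_graph=True)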
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index c59ad09bf1..5f60f62874 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -276,7 +276,7 @@ def implicit_grad(f):
def _get_arg_spec(f, params, param_args):
"""The positions of the parameters of f to be differentiated in param_args."""
try:
- args = tf_inspect.getargspec(f).args
+ args = tf_inspect.getfullargspec(f).args
except TypeError as e:
# TypeError can happen when f is a callable object.
if params is None:
@@ -591,9 +591,6 @@ def _num_elements(grad):
raise ValueError("`grad` not a Tensor or IndexedSlices.")
-_zeros_cache = context._TensorCache() # pylint: disable=protected-access
-
-
def _fast_fill(value, shape, dtype):
return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
@@ -611,10 +608,10 @@ def _zeros(shape, dtype):
device = ctx.device_name
cache_key = shape, dtype, device
- cached = _zeros_cache.get(cache_key)
+ cached = ctx.zeros_cache().get(cache_key)
if cached is None:
cached = _fast_fill(0, shape, dtype)
- _zeros_cache.put(cache_key, cached)
+ ctx.zeros_cache().put(cache_key, cached)
return cached
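The hunk above replaces the module-level `_zeros_cache` with a per-context cache keyed by `(shape, dtype, device)`. A plain-Python sketch of that memoization pattern, not part of the patch and deliberately simplified (the real `context._TensorCache` may also bound its size):

class TensorCacheSketch(object):
  """Toy stand-in for the per-context tensor cache used by _zeros()."""

  def __init__(self):
    self._cache = {}

  def get(self, key):
    return self._cache.get(key)

  def put(self, key, value):
    self._cache[key] = value

  def flush(self):
    self._cache.clear()

cache = TensorCacheSketch()
key = ((2, 3), "float32", "/device:CPU:0")   # shape, dtype, device
if cache.get(key) is None:
  # In the real code this would be _fast_fill(0, shape, dtype).
  cache.put(key, "zeros<%s>" % (key,))
print(cache.get(key))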
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 495a674526..c79294895b 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -91,6 +91,7 @@ class _EagerContext(threading.local):
self.summary_writer_resource = None
self.scalar_cache = {}
self.ones_rank_cache = _TensorCache()
+ self.zeros_cache = _TensorCache()
self.execution_mode = None
@@ -225,6 +226,24 @@ class Context(object):
"""
return self._rng.randint(0, _MAXINT32)
+ def _initialize_devices(self):
+ """Helper to initialize devices."""
+ # Store list of devices
+ self._context_devices = []
+ device_list = pywrap_tensorflow.TFE_ContextListDevices(
+ self._context_handle)
+ try:
+ self._num_gpus = 0
+ for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
+ dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
+ self._context_devices.append(pydev.canonical_name(dev_name))
+ dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
+ if dev_type == "GPU":
+ self._num_gpus += 1
+
+ finally:
+ pywrap_tensorflow.TF_DeleteDeviceList(device_list)
+
def _initialize_handle_and_devices(self):
"""Initialize handle and devices."""
with self._initialize_lock:
@@ -241,27 +260,48 @@ class Context(object):
opts, self._device_policy)
if self._execution_mode == ASYNC:
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
- if self._server_def is not None:
- server_def_str = self._server_def.SerializeToString()
- pywrap_tensorflow.TFE_ContextOptionsSetServerDef(opts, server_def_str)
self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
- # Store list of devices
- self._context_devices = []
- device_list = pywrap_tensorflow.TFE_ContextListDevices(
- self._context_handle)
- try:
- self._num_gpus = 0
- for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
- dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
- self._context_devices.append(pydev.canonical_name(dev_name))
- dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
- if dev_type == "GPU":
- self._num_gpus += 1
+ if self._server_def is not None:
+ server_def_str = self._server_def.SerializeToString()
+ pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
+ server_def_str)
- finally:
- pywrap_tensorflow.TF_DeleteDeviceList(device_list)
+ self._initialize_devices()
+
+ def _clear_caches(self):
+ self.scalar_cache().clear()
+ self.ones_rank_cache().flush()
+ self.zeros_cache().flush()
+
+ def set_server_def(self, server_def):
+ """Allow setting a server_def on the context.
+
+ When a server def is replaced, it effectively clears a bunch of caches
+ within the context. If you attempt to use a tensor object that was pointing
+ to a tensor on the remote device, it will raise an error.
+
+ Args:
+ server_def: A tensorflow::ServerDef proto.
+ Enables execution on remote devices.
+
+ Raises:
+ ValueError: if server_def is None.
+ """
+ if not server_def:
+ raise ValueError("server_def is None.")
+ if not self._context_handle:
+ self._server_def = server_def
+ else:
+ server_def_str = server_def.SerializeToString()
+ pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
+ server_def_str)
+
+ # Clear all the caches in case there are remote tensors in them.
+ self._clear_caches()
+
+ self._initialize_devices()
@property
def _handle(self):
@@ -324,6 +364,10 @@ class Context(object):
"""Per-device cache for scalars."""
return self._eager_context.ones_rank_cache
+ def zeros_cache(self):
+    """Per-device cache for zeros tensors."""
+ return self._eager_context.zeros_cache
+
@property
def scope_name(self):
"""Returns scope name for the current thread."""
@@ -735,6 +779,10 @@ def export_run_metadata():
return context().export_run_metadata()
+def set_server_def(server_def):
+ context().set_server_def(server_def)
+
+
# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
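The new `set_server_def` entry point lets an eager context be re-pointed at a cluster after creation, clearing the per-context caches so stale remote handles cannot leak through. A hedged sketch, not part of the patch; it assumes TF 1.x eager mode with this change applied and a worker actually serving at the given address (the address and job name are made up):

import tensorflow as tf
from tensorflow.python.eager import context

tf.enable_eager_execution()

cluster = tf.train.ClusterSpec({"worker": ["localhost:2222"]})
server_def = tf.train.ServerDef(
    cluster=cluster.as_cluster_def(),
    job_name="worker",
    task_index=0,
    protocol="grpc")

# Replacing the server def flushes the scalar/ones-rank/zeros caches so no
# cached handle keeps pointing at a tensor on the previous set of devices.
context.set_server_def(server_def)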
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 99129c2537..29e234efd8 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -132,11 +132,23 @@ class CapturingGraph(ops.Graph):
op_def=None,
compute_shapes=True,
compute_device=True):
- # TODO(apassos) this should do some form of alias analysis as ops which
- # forward the resources such as Identity and Switch can cause serialization
- # to fail.
+ # This capturing logic interacts poorly with control flow contexts which
+ # want to replace inputs of ops far too late in the process. This can lead
+ # the context to get confused and try to create an Enter for an Enter. We
+ # can detect this here and skip the additional Enter which can confuse loop
+ # validation logic.
+ if op_type == "Enter" and inputs[0].op.type == "Enter":
+ if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+ return inputs[0].op
+ # Calling AddValue on the control flow contexts to force creation of the
+ # backward accumulators in the original graph before we create placeholders
+ # to capture the inputs.
+ ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
for i, inp in enumerate(inputs):
- inputs[i] = self.capture(inp)
+ if ctxt is not None and hasattr(ctxt, "AddValue"):
+ inp = ctxt.AddValue(inp)
+ inp = self.capture(inp)
+ inputs[i] = inp
return super(CapturingGraph, self).create_op(
op_type, inputs, dtypes, input_types, name, attrs, op_def,
compute_device=compute_device)
@@ -457,7 +469,6 @@ class GraphModeFunction(object):
self._func_name = name
self._function_def = defined_function
self._num_outputs = len(defined_function.signature.output_arg)
- self._ops = operations
self._python_func_outputs = python_func_outputs
self._python_returns = [python_func_outputs] if isinstance(
python_func_outputs,
@@ -500,9 +511,15 @@ class GraphModeFunction(object):
extra_placeholders = []
forward_name = _forward_name(self._func_name)
+ # Note: we cannot have placeholder ops in the graph or the TPU compilation
+ # pass fails.
+ placeholder_ops = set([y.op for y in self._input_placeholders])
+ function_ops = [x for x in self._graph.get_operations()
+ if x not in placeholder_ops]
self._forward_fdef = _EagerDefinedFunction(
- forward_name, self._graph, self._ops, self._input_placeholders,
- filtered_outputs + list(extra_inputs), self._attrs)
+ forward_name, self._graph, function_ops,
+ self._input_placeholders, filtered_outputs + list(extra_inputs),
+ self._attrs)
all_inputs = self._out_grad_placeholders + list(extra_placeholders)
# Excluding input ops from the body as we do not intend to execute these
# operations when the function is executed.
@@ -680,6 +697,11 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
func_args = _get_defun_inputs(args)
func_kwds = _get_defun_inputs(kwds)
+ # Variables to help check whether mutation happens in calling the function
+ # Copy the recursive list, tuple and map structure, but not base objects
+ func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
+ func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
+
def convert(x):
if x is None:
return None
@@ -691,6 +713,25 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
try:
func_outputs = func(*func_args, **func_kwds)
func_outputs = nest.map_structure(convert, func_outputs)
+
+ def check_mutation(n1, n2):
+      """Check if two lists of arguments are exactly the same."""
+ errmsg = ("Function to be traced should not modify structure of input "
+ "arguments. Check if your function has list and dictionary "
+ "operations that alter input arguments, "
+ "such as `list.pop`, `list.append`")
+ try:
+ nest.assert_same_structure(n1, n2)
+ except ValueError:
+ raise ValueError(errmsg)
+
+ for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)):
+ if arg1 is not arg2:
+ raise ValueError(errmsg)
+
+ check_mutation(func_args_before, func_args)
+ check_mutation(func_kwds_before, func_kwds)
+
finally:
tape.pop_tape(this_tape)
variables = this_tape.watched_variables()
@@ -894,8 +935,11 @@ def defun(func=None, compiled=False):
`defun`-generated graphs.
For a Python function to be compatible with `defun`, all of its arguments must
- be hashable Python objects or lists thereof. Additionally, it must return zero
- or more @{tf.Tensor} objects.
+ be hashable Python objects or lists thereof. The function itself may not
+ modify the list/map structure of its arguments. Additionally, it must return
+ zero or more @{tf.Tensor} objects. If the Python function returns
+ a @{tf.Variable}, its compiled version will return the value of that variable
+ as a @{tf.Tensor}.
Executing a graph generated by `defun` respects device annotations (i.e.,
all `with tf.device` directives present in a Python function will also be
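The `check_mutation` guard added above turns structure-mutating functions into a tracing-time `ValueError` instead of a silently wrong graph. A user-level sketch, not part of the patch, assuming TF 1.x eager mode with this change applied:

import tensorflow as tf
from tensorflow.python.eager import function

tf.enable_eager_execution()

@function.defun
def bad(l):
  l.append(tf.constant(0.))   # alters the input list structure while tracing
  return l[0]

@function.defun
def good(l):
  return l[0] + l[1]          # only reads its inputs

print(good([tf.constant(1.), tf.constant(2.)]))   # traces and runs fine
try:
  bad([tf.constant(1.), tf.constant(2.)])
except ValueError as e:
  print(e)  # "Function to be traced should not modify structure of input arguments..."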
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 2e86563a7d..5efdecdbc6 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import iterator_ops
@@ -226,6 +227,23 @@ class FunctionTest(test.TestCase):
y = f(x)
self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0)
+ @test_util.run_in_graph_and_eager_modes()
+ def testGraphLoopGradient(self):
+ if context.executing_eagerly():
+ self.skipTest('TODO(apassos): support loops in defuns in eager')
+
+ @function.defun
+ def f(x):
+ return control_flow_ops.while_loop(lambda _, i: i < 2,
+ lambda x, i: (2*x, i + 1),
+ [x, 0])[0]
+
+ with backprop.GradientTape() as t:
+ x = constant_op.constant(1.0)
+ t.watch(x)
+ y = f(x)
+ self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0)
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@@ -1212,6 +1230,174 @@ class AutomaticControlDependenciesTest(test.TestCase):
train()
self.assertEqual(v.numpy(), -1.0)
+ def testFunctionModifiesInputList(self):
+ # Tests on `list` methods that do in place modification, except `list.sort`
+ # since it cannot even be "defunned" in the first place
+
+ def get_list():
+ return [constant_op.constant(0.), constant_op.constant(1.)]
+
+ expected_msg = (
+ 'Function to be traced should not modify structure of input '
+ 'arguments. Check if your function has list and dictionary '
+ 'operations that alter input arguments, '
+ 'such as `list.pop`, `list.append`')
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def append(l):
+ l.append(constant_op.constant(0.))
+
+ append(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def extend(l):
+ l.extend([constant_op.constant(0.)])
+
+ extend(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def insert(l):
+ l.insert(0, constant_op.constant(0.))
+
+ insert(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def pop(l):
+ l.pop()
+
+ pop(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def reverse(l):
+ l.reverse()
+
+ reverse(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def remove(l):
+ l.remove(l[0])
+
+ remove(get_list())
+
+ # `list.clear` is a method that is in Py3 but not Py2
+ if sys.version.startswith('3'):
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def clear(l):
+ l.clear()
+
+ clear(get_list())
+
+ # One last test for keyword arguments
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def kwdappend(**kwargs):
+ l = kwargs['l']
+ l.append(constant_op.constant(0.))
+
+ kwdappend(l=get_list())
+
+ def testFunctionModifiesInputDict(self):
+
+ def get_dict():
+ return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
+
+ expected_msg = (
+ 'Function to be traced should not modify structure of input '
+ 'arguments. Check if your function has list and dictionary '
+ 'operations that alter input arguments, '
+ 'such as `list.pop`, `list.append`')
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def clear(m):
+ m.clear()
+
+ clear(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def pop(m):
+ m.pop('t1')
+
+ pop(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def popitem(m):
+ m.popitem()
+
+ popitem(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def update(m):
+ m.update({'t1': constant_op.constant(3.)})
+
+ update(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def setdefault(m):
+ m.setdefault('t3', constant_op.constant(3.))
+
+ setdefault(get_dict())
+
+ def testFunctionModifiesInputNest(self):
+ # Test on functions that modify structure of nested input arguments
+ expected_msg = (
+ 'Function to be traced should not modify structure of input '
+ 'arguments. Check if your function has list and dictionary '
+ 'operations that alter input arguments, '
+ 'such as `list.pop`, `list.append`')
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def modify(n):
+ n[0]['t1'].append(constant_op.constant(1.))
+
+ nested_input = [{
+ 't1': [constant_op.constant(0.),
+ constant_op.constant(1.)],
+ },
+ constant_op.constant(2.)]
+
+ modify(nested_input)
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ # The flat list doesn't change whereas the true structure changes
+ @function.defun
+ def modify_same_flat(n):
+ n[0].append(n[1].pop(0))
+
+ nested_input = [[constant_op.constant(0.)],
+ [constant_op.constant(1.),
+ constant_op.constant(2.)]]
+
+ modify_same_flat(nested_input)
+
if __name__ == '__main__':
ops.enable_eager_execution(
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 2dc5060984..9200396c8a 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -288,7 +288,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
with tmp_graph.as_default():
# Placeholders for the non-variable inputs.
func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
- func_num_args = len(tf_inspect.getargspec(func).args)
+ func_num_args = len(tf_inspect.getfullargspec(func).args)
if len(func_inputs) != func_num_args:
raise TypeError("The number of arguments accepted by the decorated "
"function `%s` (%d) must match the number of "
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index ffdcc7c80a..43deb8bc6c 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -53,6 +53,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import evaluation
@@ -268,7 +269,7 @@ class Estimator(object):
found.
"""
with context.graph_mode():
- return saver.latest_checkpoint(self.model_dir)
+ return checkpoint_management.latest_checkpoint(self.model_dir)
def train(self,
input_fn,
@@ -417,7 +418,7 @@ class Estimator(object):
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.info('Could not find trained model in model_dir: {}, running '
'initialization to evaluate.'.format(self._model_dir))
@@ -504,7 +505,8 @@ class Estimator(object):
hooks = _check_hooks_type(hooks)
# Check that model has been trained.
if not checkpoint_path:
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._model_dir)
if not checkpoint_path:
logging.info('Could not find trained model in model_dir: {}, running '
'initialization to predict.'.format(self._model_dir))
@@ -769,7 +771,8 @@ class Estimator(object):
with context.graph_mode():
if not checkpoint_path:
# Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._model_dir)
if not checkpoint_path:
raise ValueError("Couldn't find trained model at %s." % self._model_dir)
@@ -1258,28 +1261,24 @@ class Estimator(object):
training_chief_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_chief_hooks)
- # TODO(sourabhbajaj): Merge the two code paths once we can
- # handle per device variables correctly in reduce and can output
- # the loss scaler.
+ # TODO(sourabhbajaj): Merge the two code paths and clean up the code
if is_tpu_strategy:
- loss = self._train_distribution.unwrap(
- self._train_distribution.reduce(
- distribute_lib.get_loss_reduction(), tpu_result)[0])[0]
+ distributed_loss = tpu_result
worker_hooks.append(
estimator_util.StrategyInitFinalizeHook(
self._train_distribution.get_initialization_ops,
self._train_distribution.get_finalize_ops))
else:
- loss = self._train_distribution.unwrap(
- self._train_distribution.reduce(
- distribute_lib.get_loss_reduction(),
- grouped_estimator_spec.loss,
- destinations='/device:CPU:0'))[0]
+ distributed_loss = grouped_estimator_spec.loss
distributed_train_op = grouped_estimator_spec.train_op
estimator_spec = model_fn_lib.EstimatorSpec(
mode=grouped_estimator_spec.mode,
- loss=loss,
+ loss=self._train_distribution.unwrap(
+ self._train_distribution.reduce(
+ distribute_lib.get_loss_reduction(),
+ distributed_loss,
+ destinations='/device:CPU:0'))[0],
train_op=self._train_distribution.group(distributed_train_op),
training_hooks=training_hooks,
training_chief_hooks=training_chief_hooks,
@@ -1630,7 +1629,7 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
def _check_checkpoint_available(model_dir):
- latest_path = saver.latest_checkpoint(model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(model_dir)
if not latest_path:
raise ValueError(
'Could not find trained model in model_dir: {}.'.format(model_dir))
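Every call site above keeps the same behaviour; only the module providing `latest_checkpoint` changes. A small sketch of the relocated helper, not part of the patch, assuming TF 1.x with this change applied ("/tmp/my_model_dir" is a made-up path):

from tensorflow.python.training import checkpoint_management

path = checkpoint_management.latest_checkpoint("/tmp/my_model_dir")
if path is None:
  print("no checkpoint yet; Estimator would run fresh initialization")
else:
  print("would restore from", path)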
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 68fc5bcadf..e8552092e0 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -69,6 +69,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import checkpoint_state_pb2
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
@@ -1548,7 +1549,8 @@ class EstimatorPredictTest(test.TestCase):
next(
est.predict(
dummy_input_fn,
- checkpoint_path=saver.latest_checkpoint('fakedir')))
+ checkpoint_path=
+ checkpoint_management.latest_checkpoint('fakedir')))
def test_tensor_predictions(self):
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index ca26341445..529e7a8b87 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -40,29 +40,38 @@ _SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
_SINGLE_LABEL_DEFAULT_NAME = 'label'
+_SINGLE_TENSOR_DEFAULT_NAMES = {
+ 'feature': _SINGLE_FEATURE_DEFAULT_NAME,
+ 'label': _SINGLE_LABEL_DEFAULT_NAME,
+ 'receiver_tensor': _SINGLE_RECEIVER_DEFAULT_NAME,
+ 'receiver_tensors_alternative': _SINGLE_RECEIVER_DEFAULT_NAME
+}
+
-def _wrap_and_check_receiver_tensors(receiver_tensors):
- """Ensure that receiver_tensors is a dict of str to Tensor mappings.
+def _wrap_and_check_input_tensors(tensors, field_name):
+ """Ensure that tensors is a dict of str to Tensor mappings.
Args:
- receiver_tensors: dict of str to Tensors, or a single Tensor.
+ tensors: dict of str to Tensors, or a single Tensor.
+ field_name: name of the member field of `ServingInputReceiver`
+ whose value is being passed to `tensors`.
Returns:
dict of str to Tensors; this is the original dict if one was passed, or
the original tensor wrapped in a dictionary.
Raises:
- ValueError: if receiver_tensors is None, or has non-string keys,
+ ValueError: if tensors is None, or has non-string keys,
or non-Tensor values
"""
- if receiver_tensors is None:
- raise ValueError('receiver_tensors must be defined.')
- if not isinstance(receiver_tensors, dict):
- receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}
- for name, tensor in receiver_tensors.items():
- _check_tensor_key(name, error_label='receiver_tensors')
- _check_tensor(tensor, name, error_label='receiver_tensor')
- return receiver_tensors
+ if tensors is None:
+ raise ValueError('{}s must be defined.'.format(field_name))
+ if not isinstance(tensors, dict):
+ tensors = {_SINGLE_TENSOR_DEFAULT_NAMES[field_name]: tensors}
+ for name, tensor in tensors.items():
+ _check_tensor_key(name, error_label=field_name)
+ _check_tensor(tensor, name, error_label=field_name)
+ return tensors
def _check_tensor(tensor, name, error_label='feature'):
@@ -125,15 +134,10 @@ class ServingInputReceiver(
features,
receiver_tensors,
receiver_tensors_alternatives=None):
- if features is None:
- raise ValueError('features must be defined.')
- if not isinstance(features, dict):
- features = {_SINGLE_FEATURE_DEFAULT_NAME: features}
- for name, tensor in features.items():
- _check_tensor_key(name)
- _check_tensor(tensor, name)
+ features = _wrap_and_check_input_tensors(features, 'feature')
- receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors)
+ receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors,
+ 'receiver_tensor')
if receiver_tensors_alternatives is not None:
if not isinstance(receiver_tensors_alternatives, dict):
@@ -142,17 +146,10 @@ class ServingInputReceiver(
receiver_tensors_alternatives))
for alternative_name, receiver_tensors_alt in (
six.iteritems(receiver_tensors_alternatives)):
- if not isinstance(receiver_tensors_alt, dict):
- receiver_tensors_alt = {
- _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt
- }
- # Updating dict during iteration is OK in this case.
- receiver_tensors_alternatives[alternative_name] = (
- receiver_tensors_alt)
- for name, tensor in receiver_tensors_alt.items():
- _check_tensor_key(name, error_label='receiver_tensors_alternative')
- _check_tensor(
- tensor, name, error_label='receiver_tensors_alternative')
+ # Updating dict during iteration is OK in this case.
+ receiver_tensors_alternatives[alternative_name] = (
+ _wrap_and_check_input_tensors(
+ receiver_tensors_alt, 'receiver_tensors_alternative'))
return super(ServingInputReceiver, cls).__new__(
cls,
@@ -245,16 +242,12 @@ class SupervisedInputReceiver(
def __new__(cls, features, labels, receiver_tensors):
# Both features and labels can be dicts or raw tensors.
for input_vals, error_label in ((features, 'feature'), (labels, 'label')):
- if input_vals is None:
- raise ValueError('{}s must be defined.'.format(error_label))
- if isinstance(input_vals, dict):
- for name, tensor in input_vals.items():
- _check_tensor_key(name, error_label=error_label)
- _check_tensor(tensor, name, error_label=error_label)
- else:
- _check_tensor(input_vals, None, error_label=error_label)
-
- receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors)
+ # _wrap_and_check_input_tensors is called here only to validate the
+ # tensors. The wrapped dict that is returned is deliberately discarded.
+ _wrap_and_check_input_tensors(input_vals, error_label)
+
+ receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors,
+ 'receiver_tensor')
return super(SupervisedInputReceiver, cls).__new__(
cls,
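The refactor above folds the feature/label/receiver-tensor validation into one helper that wraps a bare Tensor under a per-field default key. A simplified stand-in, not part of the patch, that mirrors only the wrapping behaviour (the tensor type checks are omitted):

import tensorflow as tf

_DEFAULT_KEYS = {'feature': 'feature', 'label': 'label',
                 'receiver_tensor': 'input'}

def wrap_like_export(tensors, field_name):
  """Sketch of _wrap_and_check_input_tensors without the tensor checks."""
  if tensors is None:
    raise ValueError('{}s must be defined.'.format(field_name))
  if not isinstance(tensors, dict):
    tensors = {_DEFAULT_KEYS[field_name]: tensors}
  return tensors

print(wrap_like_export(tf.placeholder(tf.float32), 'receiver_tensor'))
# -> {'input': <tf.Tensor 'Placeholder:0' ...>}
print(wrap_like_export({'age': tf.placeholder(tf.int64)}, 'feature'))
# -> {'age': <tf.Tensor 'Placeholder_1:0' ...>}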
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index a7074712c2..d2ac7f0b3b 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -107,7 +107,7 @@ class ServingInputReceiverTest(test_util.TensorFlowTestCase):
receiver_tensors=None)
with self.assertRaisesRegexp(
- ValueError, "receiver_tensors keys must be strings"):
+ ValueError, "receiver_tensor keys must be strings"):
export.ServingInputReceiver(
features=features,
receiver_tensors={
@@ -271,7 +271,7 @@ class SupervisedInputReceiverTest(test_util.TensorFlowTestCase):
receiver_tensors=None)
with self.assertRaisesRegexp(
- ValueError, "receiver_tensors keys must be strings"):
+ ValueError, "receiver_tensor keys must be strings"):
export.SupervisedInputReceiver(
features=features,
labels=labels,
@@ -740,7 +740,7 @@ class TensorServingReceiverTest(test_util.TensorFlowTestCase):
receiver_tensors=None)
with self.assertRaisesRegexp(
- ValueError, "receiver_tensors keys must be strings"):
+ ValueError, "receiver_tensor keys must be strings"):
export.TensorServingInputReceiver(
features=features,
receiver_tensors={
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 079560c495..c63deb8f4d 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
@@ -442,7 +443,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
# save checkpoint into subdirectory to allow warm start
keras_model_dir = os.path.join(config.model_dir, 'keras')
# Load weights and save to checkpoint if there is no checkpoint
- latest_path = saver_lib.latest_checkpoint(keras_model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(keras_model_dir)
if not latest_path:
keras_weights = None
if _any_weight_initialized(keras_model):
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index 7719d03019..6e844e14b9 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -87,17 +87,18 @@ def _parse_message(message):
return seps, tags
-def _compute_device_summary_from_list(device_assignment_list, prefix=""):
+def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
"""Return a summary of an op's device function stack.
Args:
+ name: The name of the op.
device_assignment_list: The op._device_assignments list.
prefix: An optional string prefix used before each line of the multi-
line string returned by this function.
Returns:
A multi-line string similar to:
- Device assignments active during op creation:
+ Device assignments active during op 'foo' creation:
with tf.device(/cpu:0): <test_1.py:27>
with tf.device(some_func<foo.py, 123>): <test_2.py:38>
The first line will have no padding to its left by default. Subsequent
@@ -105,11 +106,13 @@ def _compute_device_summary_from_list(device_assignment_list, prefix=""):
to increase indentation.
"""
if not device_assignment_list:
- message = "No device assignments were active during op creation."
+ message = "No device assignments were active during op '%s' creation."
+ message %= name
return prefix + message
str_list = []
- str_list.append("%sDevice assignments active during op creation:" % prefix)
+ str_list.append("%sDevice assignments active during op '%s' creation:"
+ % (prefix, name))
for traceable_obj in device_assignment_list:
location_summary = "<{file}:{line}>".format(file=traceable_obj.filename,
@@ -127,17 +130,17 @@ def _compute_device_summary_from_list(device_assignment_list, prefix=""):
def _compute_device_assignment_summary_from_op(op, prefix=""):
- if not op:
- return ""
# pylint: disable=protected-access
- return _compute_device_summary_from_list(op._device_assignments, prefix)
+ return _compute_device_summary_from_list(op.name, op._device_assignments,
+ prefix)
# pylint: enable=protected-access
-def _compute_colocation_summary_from_dict(colocation_dict, prefix=""):
+def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
"""Return a summary of an op's colocation stack.
Args:
+ name: The op name.
colocation_dict: The op._colocation_dict.
prefix: An optional string prefix used before each line of the multi-
line string returned by this function.
@@ -152,20 +155,21 @@ def _compute_colocation_summary_from_dict(colocation_dict, prefix=""):
to increase indentation.
"""
if not colocation_dict:
- message = "No node-device colocations were active during op creation."
+ message = "No node-device colocations were active during op '%s' creation."
+ message %= name
return prefix + message
str_list = []
- str_list.append("%sNode-device colocations active during op creation:"
- % prefix)
+ str_list.append("%sNode-device colocations active during op '%s' creation:"
+ % (prefix, name))
- for name, location in colocation_dict.items():
+ for coloc_name, location in colocation_dict.items():
location_summary = "<{file}:{line}>".format(file=location.filename,
line=location.lineno)
subs = {
"prefix": prefix,
"indent": " ",
- "name": name,
+ "name": coloc_name,
"loc": location_summary,
}
str_list.append(
@@ -176,11 +180,8 @@ def _compute_colocation_summary_from_dict(colocation_dict, prefix=""):
def _compute_colocation_summary_from_op(op, prefix=""):
"""Fetch colocation file, line, and nesting and return a summary string."""
- if not op:
- return ""
- # pylint: disable=protected-access
- return _compute_colocation_summary_from_dict(op._colocation_dict, prefix)
- # pylint: enable=protected-access
+ return _compute_colocation_summary_from_dict(
+ op.name, op._colocation_dict, prefix) # pylint: disable=protected-access
def _find_index_of_defining_frame_for_op(op):
@@ -216,16 +217,14 @@ def _find_index_of_defining_frame_for_op(op):
def _get_defining_frame_from_op(op):
"""Find and return stack frame where op was defined."""
- frame = None
- if op:
- # pylint: disable=protected-access
- frame_index = _find_index_of_defining_frame_for_op(op)
- frame = op._traceback[frame_index]
- # pylint: enable=protected-access
+ frame_index = _find_index_of_defining_frame_for_op(op)
+ # pylint: disable=protected-access
+ frame = op._traceback[frame_index]
+ # pylint: enable=protected-access
return frame
-def _compute_field_dict(op):
+def compute_field_dict(op):
"""Return a dictionary mapping interpolation tokens to values.
Args:
@@ -237,32 +236,40 @@ def _compute_field_dict(op):
{
"file": "tool_utils.py",
"line": "124",
+ "defined_at": " (defined at tool_utils.py:124)",
"colocations":
'''Node-device colocations active during op creation:
with tf.colocate_with(test_node_1): <test_1.py:27>
with tf.colocate_with(test_node_2): <test_2.py:38>'''
+ "devices":
+ '''Device assignments active during op 'foo' creation:
+ with tf.device(/cpu:0): <test_1.py:27>
+ with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
+ "devs_and_colocs": A concatenation of colocations and devices, e.g.
+ '''Node-device colocations active during op creation:
+ with tf.colocate_with(test_node_1): <test_1.py:27>
+ with tf.colocate_with(test_node_2): <test_2.py:38>'''
+ Device assignments active during op 'foo' creation:
+ with tf.device(/cpu:0): <test_1.py:27>
+ with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
}
- If op is None or lacks a _traceback field, the returned values will be
- "<NA>".
"""
- default_value = "<NA>"
- field_dict = {
- "file": default_value,
- "line": default_value,
- "colocations": default_value,
- "devices": default_value,
- }
frame = _get_defining_frame_from_op(op)
- if frame:
- field_dict["file"] = frame[tf_stack.TB_FILENAME]
- field_dict["line"] = frame[tf_stack.TB_LINENO]
+ filename = frame[tf_stack.TB_FILENAME]
+ lineno = frame[tf_stack.TB_LINENO]
+ defined_at = " (defined at %s:%d)" % (filename, lineno)
colocation_summary = _compute_colocation_summary_from_op(op)
- if colocation_summary:
- field_dict["colocations"] = colocation_summary
device_summary = _compute_device_assignment_summary_from_op(op)
- if device_summary:
- field_dict["devices"] = device_summary
+ combined_summary = "\n".join([colocation_summary, device_summary])
+ field_dict = {
+ "file": filename,
+ "line": lineno,
+ "defined_at": defined_at,
+ "colocations": colocation_summary,
+ "devices": device_summary,
+ "devs_and_colocs": combined_summary,
+ }
return field_dict
@@ -291,7 +298,12 @@ def interpolate(error_message, graph):
except KeyError:
op = None
- node_name_to_substitution_dict[name] = _compute_field_dict(op)
+ if op is not None:
+ field_dict = compute_field_dict(op)
+ else:
+ msg = "<NA>"
+ field_dict = collections.defaultdict(lambda s=msg: s)
+ node_name_to_substitution_dict[name] = field_dict
subs = [
string.Template(tag.format).safe_substitute(
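With this change, `compute_field_dict` always returns concrete values, including the new `defined_at` and `devs_and_colocs` fields, and `interpolate` only falls back to a defaultdict of "<NA>" when the named op is missing from the graph. A stdlib-only sketch of the `string.Template` substitution those fields feed, not part of the patch (the field values are made up):

import string

field_dict = {
    "file": "tool_utils.py",
    "line": 124,
    "defined_at": " (defined at tool_utils.py:124)",
    "devs_and_colocs": "No node-device colocations were active during "
                       "op 'foo' creation.",
}

tag_format = "${file}:${line}"   # the format part of a ^^node:foo:...^^ tag
print(string.Template(tag_format).safe_substitute(field_dict))
# -> tool_utils.py:124
print(string.Template("Op 'foo'${defined_at}").safe_substitute(field_dict))
# -> Op 'foo' (defined at tool_utils.py:124)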
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index fbf182879b..0427156b2b 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -71,8 +71,9 @@ class ComputeDeviceSummaryFromOpTest(test.TestCase):
lineno=42))
summary = error_interpolation._compute_device_summary_from_list(
- assignments, prefix=" ")
+ "nodename", assignments, prefix=" ")
+ self.assertIn("nodename", summary)
self.assertIn("tf.device(/cpu:0)", summary)
self.assertIn("<hope.py:24>", summary)
self.assertIn("tf.device(/gpu:2)", summary)
@@ -81,7 +82,8 @@ class ComputeDeviceSummaryFromOpTest(test.TestCase):
def testCorrectFormatWhenNoColocationsWereActive(self):
device_assignment_list = []
summary = error_interpolation._compute_device_summary_from_list(
- device_assignment_list, prefix=" ")
+ "nodename", device_assignment_list, prefix=" ")
+ self.assertIn("nodename", summary)
self.assertIn("No device assignments", summary)
@@ -99,7 +101,8 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
"test_node_2": t_obj_2,
}
summary = error_interpolation._compute_colocation_summary_from_dict(
- colocation_dict, prefix=" ")
+ "node_name", colocation_dict, prefix=" ")
+ self.assertIn("node_name", summary)
self.assertIn("colocate_with(test_node_1)", summary)
self.assertIn("<test_1.py:27>", summary)
self.assertIn("colocate_with(test_node_2)", summary)
@@ -108,7 +111,8 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
def testCorrectFormatWhenNoColocationsWereActive(self):
colocation_dict = {}
summary = error_interpolation._compute_colocation_summary_from_dict(
- colocation_dict, prefix=" ")
+ "node_name", colocation_dict, prefix=" ")
+ self.assertIn("node_name", summary)
self.assertIn("No node-device colocations", summary)
@@ -176,7 +180,7 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
one_tag_string = "^^node:MinusOne:${file}^^"
interpolated_string = error_interpolation.interpolate(one_tag_string,
self.graph)
- self.assertEqual(interpolated_string, "<NA>")
+ self.assertEqual("<NA>", interpolated_string)
def testTwoTagsNoSeps(self):
two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^"
@@ -287,7 +291,6 @@ class InterpolateColocationSummaryTest(test.TestCase):
message = "^^node:One:${colocations}^^"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No node-device colocations", result)
- self.assertNotIn("One", result)
self.assertNotIn("Two", result)
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index c76743d2c6..12bf03c5fa 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -819,7 +819,7 @@ class _FuncGraph(ops.Graph):
def func_graph_from_py_func(func, arg_names, arg_types, name=None,
capture_by_value=False, device=None,
colocation_stack=None, container=None,
- collections_ref=None):
+ collections_ref=None, arg_shapes=None):
"""Returns a _FuncGraph generated from `func`.
Args:
@@ -836,6 +836,7 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
container: A container name the _FuncGraph should start with.
collections_ref: A reference to a collections dict the _FuncGraph should
use internally.
+ arg_shapes: A sequence of the function's argument shapes.
Returns:
A _FuncGraph.
@@ -857,9 +858,12 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
func_graph._colocation_stack = colocation_stack
# pylint: enable=protected-access
+ if arg_shapes is None:
+ arg_shapes = [None] * len(arg_types)
+
# Create placeholders for the function arguments.
- for (argname, argtype) in zip(arg_names, arg_types):
- argholder = array_ops.placeholder(argtype, name=argname)
+ for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
+ argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
func_graph.inputs.append(argholder)
# Call func and gather the output tensors.
with vs.variable_scope("", custom_getter=func_graph.getvar):
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index c25e29b0f4..ed0bf1afe0 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -44,6 +44,7 @@ from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import registry
@@ -454,7 +455,7 @@ class Tensor(_TensorLike):
def __iter__(self):
if not context.executing_eagerly():
raise TypeError(
- "Tensor objects are not iterable when eager execution is not "
+ "Tensor objects are only iterable when eager execution is "
"enabled. To iterate over this tensor use tf.map_fn.")
shape = self._shape_tuple()
if shape is None:
@@ -3292,6 +3293,36 @@ class Graph(object):
self._create_op_helper(ret, compute_device=compute_device)
return ret
+ def _make_colocation_conflict_message(self, op, colocation_op):
+ """Return detailed error message about device conflict due to colocation."""
+ # Example error message:
+ # Tried to colocate op 'a' (defined at file1.py:149) having device
+ # '/device:GPU:0' with op 'b' (defined at file2:96) which had an
+ # incompatible device '/device:CPU:0'.
+ #
+ # No node-device colocations were active during op 'a' creation.
+ # Device assignments active during op 'a' creation:
+ # with tf.device(/device:GPU:0): file1.py:148>
+ #
+ # Node-device colocations active during op 'b' creation:
+ # with tf.colocate_with(a): file2.py:93>
+ # Device assignments active during op 'b' creation:
+ # with tf.device(/cpu:0): file2.py:94
+ op_info = error_interpolation.compute_field_dict(op)
+ coloc_op_info = error_interpolation.compute_field_dict(colocation_op)
+ msg = ("Tried to colocate op '{op_name}'{op_loc} having device '{op_dev}' "
+ "with op '{coloc_op_name}'{coloc_op_loc} which had an incompatible "
+ "device '{coloc_op_dev}'.\n\n{op_summary}\n\n{coloc_op_summary}"
+ .format(op_name=op.name,
+ op_loc=op_info["defined_at"],
+ op_dev=op.device,
+ op_summary=op_info["devs_and_colocs"],
+ coloc_op_name=colocation_op.name,
+ coloc_op_loc=coloc_op_info["defined_at"],
+ coloc_op_dev=colocation_op.device,
+ coloc_op_summary=coloc_op_info["devs_and_colocs"]))
+ return msg
+
def _create_op_helper(self, op, compute_device=True):
"""Common logic for creating an op in this graph."""
# Apply any additional attributes requested. Do not overwrite any existing
@@ -3332,20 +3363,22 @@ class Graph(object):
if compute_device:
self._apply_device_functions(op)
+ # Snapshot the colocation stack metadata before we might generate error
+ # messages using it. Note that this snapshot depends on the actual stack
+ # and is independent of the op's _class attribute.
+ # pylint: disable=protected-access
+ op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
+ # pylint: enable=protected-access
+
if self._colocation_stack:
all_colocation_groups = []
for colocation_op in self._colocation_stack.peek_objs():
all_colocation_groups.extend(colocation_op.colocation_groups())
if colocation_op.device:
- # Make this device match the device of the colocated op, to provide
- # consistency between the device and the colocation property.
if (op.device and pydev.canonical_name(op.device) !=
pydev.canonical_name(colocation_op.device)):
- logging.warning("Tried to colocate %s with an op %s that had "
- "a different device: %s vs %s. Postponing "
- "error-checking until all devices are assigned.",
- op.name, colocation_op.name, op.device,
- colocation_op.device)
+ msg = self._make_colocation_conflict_message(op, colocation_op)
+ logging.warning(msg)
else:
op._set_device(colocation_op.device) # pylint: disable=protected-access
@@ -3353,7 +3386,6 @@ class Graph(object):
# pylint: disable=protected-access
op._set_attr("_class", attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
- op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
# pylint: enable=protected-access
# Sets "container" attribute if
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 48328a7f58..318387c61b 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2728,6 +2728,28 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
self.assertEqual("/device:CPU:0", b.device)
+ def testMakeColocationConflictMessage(self):
+ """Test that provides an example of a complicated error message."""
+ # We could test the message with any ops, but this test will be more
+ # instructive with a real colocation conflict.
+ with ops.device("/device:GPU:0"):
+ a = constant_op.constant([2.0], name="a")
+ with ops.colocate_with(a.op):
+ with ops.device("/cpu:0"):
+ b = constant_op.constant([3.0], name="b")
+ # The definition-location of the nodes will be wrong because of running
+ # from within a TF unittest. The rest of the info should be correct.
+ message = ops.get_default_graph()._make_colocation_conflict_message(a.op,
+ b.op)
+ self.assertRegexpMatches(message,
+ r"Tried to colocate op 'a' \(defined at.*\)")
+ self.assertRegexpMatches(message, "No node-device.*'a'")
+ self.assertRegexpMatches(message, "Device assignments active.*'a'")
+ self.assertRegexpMatches(message, "GPU:0")
+ self.assertRegexpMatches(message, "Node-device colocations active.*'b'")
+ self.assertRegexpMatches(message, "Device assignments active.*'b'")
+ self.assertRegexpMatches(message, "cpu:0")
+
class DeprecatedTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index fc47b1cca5..764e8bfacb 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -51,7 +51,6 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
-from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape # pylint: disable=unused-import
from tensorflow.python.framework import device as pydev
@@ -498,9 +497,7 @@ def assert_no_new_tensors(f):
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
- backprop._zeros_cache.flush()
- context.get_default_context().ones_rank_cache().flush()
- context.get_default_context().scalar_cache().clear()
+ context.get_default_context()._clear_caches() # pylint: disable=protected-access
gc.collect()
tensors_after = [
obj for obj in gc.get_objects()
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 0616b29494..1706158c65 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -114,12 +114,14 @@ py_library(
"constraints.py",
"engine/__init__.py",
"engine/base_layer.py",
+ "engine/distributed_training_utils.py",
"engine/input_layer.py",
"engine/network.py",
"engine/saving.py",
"engine/sequential.py",
"engine/training.py",
"engine/training_arrays.py",
+ "engine/training_distributed.py",
"engine/training_eager.py",
"engine/training_generator.py",
"engine/training_utils.py",
@@ -870,7 +872,7 @@ py_test(
py_test(
name = "models_test",
- size = "small",
+ size = "medium",
srcs = ["models_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"], # b/67509773
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 38794f1612..418586b85f 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -648,7 +648,7 @@ def variable(value, dtype=None, name=None, constraint=None):
constraint=constraint)
if isinstance(value, np.ndarray):
v._keras_shape = value.shape
- elif hasattr(value, 'get_shape'):
+ elif hasattr(value, 'shape'):
v._keras_shape = int_shape(value)
v._uses_learning_phase = False
return v
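Editorial note: the get_shape() -> .shape substitutions throughout this file are behavior-preserving, since for a tf.Tensor the .shape property and get_shape() return the same TensorShape. A quick illustrative check (not part of the patch):

    import tensorflow as tf

    t = tf.zeros([2, 3])
    assert t.shape == t.get_shape()       # both yield TensorShape([2, 3])
    assert t.shape.as_list() == [2, 3]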
@@ -736,9 +736,10 @@ def is_keras_tensor(x):
True
```
"""
- if not isinstance(x, (ops.Tensor,
- variables_module.Variable,
- sparse_tensor.SparseTensor)):
+ if (not isinstance(x, (ops.Tensor,
+ variables_module.Variable,
+ sparse_tensor.SparseTensor)) and
+ x.__class__.__name__ != 'DeferredTensor'):
raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
'`. Expected a symbolic tensor instance.')
return hasattr(x, '_keras_history')
@@ -853,7 +854,10 @@ def int_shape(x):
```
"""
try:
- return tuple(x.get_shape().as_list())
+ shape = x.shape
+ if not isinstance(shape, tuple):
+ shape = tuple(shape.as_list())
+ return shape
except ValueError:
return None
@@ -880,7 +884,7 @@ def ndim(x):
2
```
"""
- dims = x.get_shape()._dims
+ dims = x.shape._dims
if dims is not None:
return len(dims)
return None
@@ -968,7 +972,7 @@ def zeros(shape, dtype=None, name=None):
dtype = floatx()
tf_dtype = dtypes_module.as_dtype(dtype)
v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
+ if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
return v
@@ -1002,7 +1006,7 @@ def ones(shape, dtype=None, name=None):
dtype = floatx()
tf_dtype = dtypes_module.as_dtype(dtype)
v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
+ if py_all(v.shape.as_list()):
return variable(v, dtype=dtype, name=name)
return v
@@ -1196,7 +1200,7 @@ def count_params(x):
[ 0., 0., 0.]], dtype=float32)
```
"""
- return np.prod(x.get_shape().as_list())
+ return np.prod(x.shape.as_list())
@tf_export('keras.backend.cast')
@@ -2115,10 +2119,10 @@ def _fused_normalize_batch_in_training(x,
if gamma is None:
gamma = constant_op.constant(
- 1.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]])
+ 1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
if beta is None:
beta = constant_op.constant(
- 0.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]])
+ 0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
return nn.fused_batch_norm(
x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
@@ -2323,7 +2327,7 @@ def repeat_elements(x, rep, axis):
Returns:
A tensor.
"""
- x_shape = x.get_shape().as_list()
+ x_shape = x.shape.as_list()
# For static axis
if x_shape[axis] is not None:
# slices along the repeat axis
@@ -2343,7 +2347,7 @@ def repeat_elements(x, rep, axis):
auxiliary_axis = axis + 1
x_shape = array_ops.shape(x)
x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
- reps = np.ones(len(x.get_shape()) + 1)
+ reps = np.ones(len(x.shape) + 1)
reps[auxiliary_axis] = rep
x_rep = array_ops.tile(x_rep, reps)
@@ -2355,7 +2359,7 @@ def repeat_elements(x, rep, axis):
x_rep = array_ops.reshape(x_rep, x_shape)
# Fix shape representation
- x_shape = x.get_shape().as_list()
+ x_shape = x.shape.as_list()
x_rep.set_shape(x_shape)
x_rep._keras_shape = tuple(x_shape)
return x_rep
@@ -2934,8 +2938,8 @@ def function(inputs, outputs, updates=None, **kwargs):
"""
if kwargs:
for key in kwargs:
- if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
- key not in tf_inspect.getargspec(Function.__init__)[0]):
+ if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0]
+ and key not in tf_inspect.getfullargspec(Function.__init__)[0]):
msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
'backend') % key
raise ValueError(msg)
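Editorial note: the getargspec -> getfullargspec switch here (and in base_layer.py and network.py below) matters on Python 3, where inspect.getargspec is deprecated and raises ValueError for functions with keyword-only arguments. An illustrative sketch using plain inspect (tf_inspect wraps the same API):

    import inspect

    def run(fetches, *, feed_dict=None):    # keyword-only argument (Python 3)
      pass

    # inspect.getargspec(run) would raise ValueError here.
    spec = inspect.getfullargspec(run)
    assert spec.args == ['fetches']
    assert spec.kwonlyargs == ['feed_dict']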
@@ -3032,17 +3036,17 @@ def rnn(step_function,
ValueError: if `mask` is provided (not `None`) but states is not provided
(`len(states)` == 0).
"""
- ndim = len(inputs.get_shape())
+ ndim = len(inputs.shape)
if ndim < 3:
raise ValueError('Input should be at least 3D.')
- inputs_shape = inputs.get_shape()
+ inputs_shape = inputs.shape
axes = [1, 0] + list(range(2, ndim))
inputs = array_ops.transpose(inputs, (axes))
if mask is not None:
if mask.dtype != dtypes_module.bool:
mask = math_ops.cast(mask, dtypes_module.bool)
- if len(mask.get_shape()) == ndim - 1:
+ if len(mask.shape) == ndim - 1:
mask = expand_dims(mask)
mask = array_ops.transpose(mask, axes)
@@ -3053,7 +3057,7 @@ def rnn(step_function,
uses_learning_phase = False
if unroll:
- if not inputs.get_shape()[0]:
+ if not inputs.shape[0]:
raise ValueError('Unrolling requires a fixed number of timesteps.')
states = initial_states
successive_states = []
@@ -3170,7 +3174,7 @@ def rnn(step_function,
global uses_learning_phase # pylint: disable=global-variable-undefined
uses_learning_phase = True
for state, new_state in zip(states, new_states):
- new_state.set_shape(state.get_shape())
+ new_state.set_shape(state.shape)
tiled_mask_t = array_ops.tile(mask_t,
array_ops.stack(
[1, array_ops.shape(output)[1]]))
@@ -3207,7 +3211,7 @@ def rnn(step_function,
global uses_learning_phase # pylint: disable=global-variable-undefined
uses_learning_phase = True
for state, new_state in zip(states, new_states):
- new_state.set_shape(state.get_shape())
+ new_state.set_shape(state.shape)
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)
@@ -3225,11 +3229,11 @@ def rnn(step_function,
outputs = output_ta.stack()
last_output = output_ta.read(last_time - 1)
- axes = [1, 0] + list(range(2, len(outputs.get_shape())))
+ axes = [1, 0] + list(range(2, len(outputs.shape)))
outputs = array_ops.transpose(outputs, axes)
# Static shape inference: (samples, time, ...)
- outputs_shape = outputs.get_shape().as_list()
+ outputs_shape = outputs.shape.as_list()
outputs_shape[0] = inputs_shape[0]
outputs_shape[1] = inputs_shape[1]
outputs.set_shape(outputs_shape)
@@ -3500,7 +3504,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
Raises:
ValueError: if `axis` is neither -1 nor one of the axes of `output`.
"""
- rank = len(output.get_shape())
+ rank = len(output.shape)
axis = axis % rank
# Note: nn.softmax_cross_entropy_with_logits_v2
# expects logits, Keras expects probabilities.
@@ -3536,7 +3540,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
Raises:
ValueError: if `axis` is neither -1 nor one of the axes of `output`.
"""
- rank = len(output.get_shape())
+ rank = len(output.shape)
axis = axis % rank
if axis != rank - 1:
permutation = list(range(axis)) + list(range(axis + 1, rank)) + [axis]
@@ -3549,7 +3553,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
output = math_ops.log(output)
- output_shape = output.get_shape()
+ output_shape = output.shape
targets = cast(flatten(target), 'int64')
logits = array_ops.reshape(output, [-1, int(output_shape[-1])])
res = nn.sparse_softmax_cross_entropy_with_logits(
@@ -3796,7 +3800,7 @@ def conv1d(x,
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
- kernel_shape = kernel.get_shape().as_list()
+ kernel_shape = kernel.shape.as_list()
if padding == 'causal':
# causal (dilated) convolution:
left_pad = dilation_rate * (kernel_shape[0] - 1)
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index e1214f8103..33ad155072 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -175,7 +175,7 @@ class Layer(checkpointable.CheckpointableBase):
self.supports_masking = False
- call_argspec = tf_inspect.getargspec(self.call)
+ call_argspec = tf_inspect.getfullargspec(self.call)
if 'training' in call_argspec.args:
self._expects_training_arg = True
else:
@@ -735,9 +735,11 @@ class Layer(checkpointable.CheckpointableBase):
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
if (not hasattr(self, '_is_graph_network') or
- self.__class__.__name__ == 'Sequential'):
- # Only if self is a layer or an instance of a sequential model do we
- # need to build it.
+ self.__class__.__name__ == 'Sequential' or
+ not hasattr(self.build, '_is_default')):
+ # Only if self is a layer, an instance of a sequential model, or
+ # the user has manually overridden the build method do we need to
+ # build it.
self.build(input_shapes)
# We must set self.built since user defined build functions are not
# constrained to set self.built.
@@ -771,7 +773,6 @@ class Layer(checkpointable.CheckpointableBase):
if build_graph:
self._handle_activity_regularization(inputs, outputs)
- # TODO(fchollet): consider enabling masking for Eager mode.
self._set_mask_metadata(inputs, outputs, previous_mask)
if in_deferred_mode or build_graph and have_all_keras_metadata(inputs):
@@ -828,21 +829,27 @@ class Layer(checkpointable.CheckpointableBase):
pass
def _set_mask_metadata(self, inputs, outputs, previous_mask):
- if hasattr(self, 'compute_mask'):
+ # In some cases the mask of the outputs has already been computed by
+ # inner layers and does not need to be recomputed by this layer.
+ mask_already_computed = all(
+ hasattr(x, '_keras_mask') for x in generic_utils.to_list(outputs))
+ if hasattr(self, 'compute_mask') and not mask_already_computed:
output_mask = self.compute_mask(inputs, previous_mask)
- if isinstance(outputs, (list, tuple)):
- if output_mask is None:
- output_mask = [None for _ in range(len(outputs))]
- for x, m in zip(outputs, output_mask):
- try:
- x._keras_mask = m # pylint: disable=protected-access
- except AttributeError:
- pass # C type such as dict. Masking not supported in this case.
- else:
+ else:
+ output_mask = None
+ if isinstance(outputs, (list, tuple)):
+ if output_mask is None:
+ output_mask = [None for _ in range(len(outputs))]
+ for x, m in zip(outputs, output_mask):
try:
- outputs._keras_mask = output_mask # pylint: disable=protected-access
+ x._keras_mask = m # pylint: disable=protected-access
except AttributeError:
pass # C type such as dict. Masking not supported in this case.
+ else:
+ try:
+ outputs._keras_mask = output_mask # pylint: disable=protected-access
+ except AttributeError:
+ pass # C type such as dict. Masking not supported in this case.
def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
call_convention = getattr(self, '_call_convention',
@@ -904,7 +911,7 @@ class Layer(checkpointable.CheckpointableBase):
assert len(call_args) == 1 # TypeError raised earlier in __call__.
return call_args[0], call_kwargs
else:
- call_arg_spec = tf_inspect.getargspec(self.call)
+ call_arg_spec = tf_inspect.getfullargspec(self.call)
# There is no explicit "inputs" argument expected or provided to
# call(). Arguments which have default values are considered non-inputs,
# and arguments without are considered inputs.
@@ -924,8 +931,8 @@ class Layer(checkpointable.CheckpointableBase):
_, unwrapped_call = tf_decorator.unwrap(self.call)
bound_args = inspect.getcallargs(
unwrapped_call, *call_args, **call_kwargs)
- if call_arg_spec.keywords is not None:
- var_kwargs = bound_args.pop(call_arg_spec.keywords)
+ if call_arg_spec.varkw is not None:
+ var_kwargs = bound_args.pop(call_arg_spec.varkw)
bound_args.update(var_kwargs)
keyword_arg_names = keyword_arg_names.union(var_kwargs.keys())
all_args = call_arg_spec.args
@@ -1958,15 +1965,10 @@ def make_variable(name,
return v
-def generate_dummy_data_from_shape(shape):
- if isinstance(shape, tensor_shape.TensorShape):
- shape = shape.as_list()
-
- # Replace Nones in input shape with dummy `1` value
- shape = [x.value if isinstance(x, tensor_shape.Dimension) else x
- for x in shape]
- shape = [1 if x is None else x for x in shape]
- return array_ops.ones(shape, dtype=backend.floatx())
+def default(method):
+ """Decorates a method to detect overrides in subclasses."""
+ method._is_default = True
+ return method
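Editorial note: this default() marker pairs with the `not hasattr(self.build, '_is_default')` check added in base_layer's __call__ above and the `@base_layer.default` decoration of Network.build below: a subclass that overrides build() loses the marker, so the framework knows it must call the user's build(). A standalone sketch of the idiom (names local to the sketch):

    def default(method):
      method._is_default = True
      return method

    class Base(object):
      @default
      def build(self, input_shape):
        pass

    class Custom(Base):
      def build(self, input_shape):   # override carries no _is_default marker
        pass

    assert hasattr(Base().build, '_is_default')
    assert not hasattr(Custom().build, '_is_default')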
def generate_placeholders_from_shape(shape):
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
new file mode 100644
index 0000000000..c78e6fe9ec
--- /dev/null
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -0,0 +1,249 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities related to distributed training."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import callbacks
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import nest
+
+
+def set_weights(distribution_strategy, dist_model, weights):
+ """Sets the weights of the replicated models.
+
+ The weights of the replicated models are set to the weights of the original
+ model. The weights of the replicated model are Mirrored variables and hence
+ we need to use the `update` call within a DistributionStrategy scope.
+
+ Args:
+ distribution_strategy: DistributionStrategy used to distribute training
+ and validation.
+ dist_model: The replicated models on the different devices.
+ weights: The weights of the original model.
+ """
+ assign_ops = []
+ for layer in dist_model.layers:
+ num_param = len(layer.weights)
+ layer_weights = weights[:num_param]
+ for sw, w in zip(layer.weights, layer_weights):
+ assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
+
+ weights = weights[num_param:]
+ backend.get_session().run(assign_ops)
+
+
+def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
+ grouped_updates, grouped_session_args,
+ with_loss_tensor=False):
+ """Unwrap and return the list of values contained in the PerDevice parameters.
+
+ This function calls `flatten_perdevice_values` to parse each of the input
+ parameters into a list of values on the different devices. If we set
+ `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
+ the different devices to give us one loss tensor.
+
+ Args:
+ distribution_strategy: DistributionStrategy used to distribute training and
+ validation.
+ grouped_inputs: PerDevice inputs returned from the train or test function
+ that we ran on each device.
+ grouped_outputs: PerDevice outputs returned from the train or test function
+ that we ran on each device.
+ grouped_updates: PerDevice updates returned from the train or test function
+ that we ran on each device.
+ grouped_session_args: PerDevice session args returned from the train or
+ test function that we ran on each device.
+ with_loss_tensor: Boolean that indicates if we need to add the reduced loss
+ tensor as one of the outputs.
+
+ Returns:
+ Values of each of the PerDevice parameters.
+
+ """
+ # Unwrap per device values returned from each model's train function.
+ # This will be used to construct the main train function.
+ all_inputs = flatten_perdevice_values(distribution_strategy,
+ grouped_inputs)
+ if with_loss_tensor:
+ # reduce loss tensor before adding it to the list of fetches
+ loss = distribution_strategy.unwrap(
+ distribution_strategy.reduce(distribute_lib.get_loss_reduction(),
+ grouped_outputs[0],
+ destinations='/device:CPU:0'))[0]
+
+ all_outputs = flatten_perdevice_values(distribution_strategy,
+ grouped_outputs[1:])
+ all_outputs = [loss] + all_outputs
+ else:
+ all_outputs = flatten_perdevice_values(distribution_strategy,
+ grouped_outputs)
+
+ all_updates = flatten_perdevice_values(distribution_strategy,
+ grouped_updates)
+
+ all_session_args = {}
+ grouped_feed_dict = grouped_session_args.get('feed_dict')
+ if grouped_feed_dict:
+ all_session_args['feed_dict'] = flatten_perdevice_values(
+ distribution_strategy, grouped_feed_dict)
+
+ grouped_fetches = grouped_session_args.get('fetches')
+ if grouped_fetches:
+ all_session_args['fetches'] = flatten_perdevice_values(
+ distribution_strategy, grouped_fetches)
+
+ return all_inputs, all_outputs, all_updates, all_session_args
+
+
+def flatten_perdevice_values(distribution_strategy, perdevice_values):
+ """Unwraps and flattens a nest of PerDevice parameters.
+
+ PerDevice values have one value associated with each device. Each entry in
+ the PerDevice dict has a device `key` and the corresponding value on the
+ device as the `value`. In this function we take a PerDevice value or a list of
+ PerDevice values and return all the values in the PerDevice dict.
+
+ Args:
+ distribution_strategy: DistributionStrategy used to distribute training and
+ validation.
+ perdevice_values: List of PerDevice objects or a single PerDevice object.
+
+ Returns:
+ List of values of all the PerDevice objects.
+
+ """
+ # This function takes a PerDevice object or a list of PerDevice objects and
+ # returns all the values associated with it.
+ return [e for flattened in nest.flatten(perdevice_values)
+ for e in distribution_strategy.unwrap(flattened)]
+
+
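Editorial note: rough illustration of the flattening behavior with stand-in objects (real PerDevice values and strategies need a full device setup, so these class names are stand-ins only):

    from tensorflow.python.util import nest

    class FakePerDevice(object):            # stand-in for a PerDevice value
      def __init__(self, *per_device_values):
        self.per_device_values = per_device_values

    class FakeStrategy(object):             # stand-in for a DistributionStrategy
      def unwrap(self, value):
        return list(value.per_device_values)

    def flatten(strategy, perdevice_values):
      return [e for flattened in nest.flatten(perdevice_values)
              for e in strategy.unwrap(flattened)]

    flat = flatten(FakeStrategy(),
                   [FakePerDevice('x_gpu0', 'x_gpu1'),
                    FakePerDevice('y_gpu0', 'y_gpu1')])
    assert flat == ['x_gpu0', 'x_gpu1', 'y_gpu0', 'y_gpu1']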
+def validate_callbacks(input_callbacks):
+ """Validate whether given callbacks are supported by DistributionStrategy.
+
+ Args:
+ input_callbacks: List of callbacks passed by the user to fit.
+
+ Raises:
+ ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the
+ callbacks passed.
+ ValueError: If `histogram_freq` or `write_grads` is one of the parameters
+ passed as part of the TensorBoard callback.
+ """
+ if input_callbacks:
+ for callback in input_callbacks:
+ if callback not in [callbacks.TensorBoard, callbacks.ReduceLROnPlateau,
+ callbacks.LearningRateScheduler, callbacks.CSVLogger,
+ callbacks.EarlyStopping, callbacks.ModelCheckpoint,
+ callbacks.TerminateOnNaN, callbacks.ProgbarLogger,
+ callbacks.History, callbacks.RemoteMonitor]:
+ logging.warning('Your input callback is not one of the predefined '
+ 'Callbacks that supports DistributionStrategy. You '
+ 'might encounter an error if you access one of the '
+ 'model\'s attributes as part of the callback since '
+ 'these attributes are not set. You can access each of '
+ 'the individual distributed models using the '
+ '`_grouped_model` attribute of your original model.')
+ if isinstance(callback, callbacks.LearningRateScheduler):
+ raise ValueError('LearningRateScheduler callback is not supported with '
+ 'DistributionStrategy.')
+ if isinstance(callback, callbacks.ReduceLROnPlateau):
+ raise ValueError('ReduceLROnPlateau callback is not supported with '
+ 'DistributionStrategy.')
+
+ # If users want to use the TensorBoard callback they cannot use certain
+ # features of the callback that involve accessing model attributes and
+ # running ops.
+ if isinstance(callback, callbacks.TensorBoard):
+ if callback.__getattribute__('histogram_freq'):
+ raise ValueError('histogram_freq in the TensorBoard callback is not '
+ 'supported when using DistributionStrategy.')
+ if callback.__getattribute__('write_grads'):
+ raise ValueError('write_grads in the TensorBoard callback is not '
+ 'supported when using DistributionStrategy.')
+
+
+def validate_distributed_dataset_inputs(distribution_strategy, x, y):
+ """Validate all the components of a DistributedValue Dataset input.
+
+ Args:
+ distribution_strategy: The current DistributionStrategy used to call
+ `fit`/`evaluate`.
+ x: Input Dataset DistributedValue object. For example, when we use
+ `MirroredStrategy` this is a PerDevice object with a tensor for each
+ device set in the dict.
+ y: Target Dataset DistributedValue object. For example, when we use
+ `MirroredStrategy` this is a PerDevice object with a tensor for each
+ device set in the dict.
+
+ Returns:
+ The unwrapped values list of the x and y DistributedValues inputs.
+
+ Raises:
+ ValueError: If x and y do not have support for being evaluated as tensors,
+ or if x and y contain elements that are not tensors, or if x and y
+ contain elements that have a shape or dtype mismatch.
+ """
+ # If the input and target used to call the model are not dataset tensors,
+ # we need to raise an error. When using a DistributionStrategy, the input
+ # and targets to a model should be from a `tf.data.Dataset`.
+
+ # If each element of x and y are not tensors, we cannot standardize and
+ # validate the input and targets.
+ if not tensor_util.is_tensor(x):
+ raise ValueError('Dataset input to the model should be tensors instead they'
+ ' are of type {}'.format(type(x)))
+
+ if not tensor_util.is_tensor(y):
+ raise ValueError('Dataset input to the model should be tensors instead they'
+ ' are of type {}'.format(type(y)))
+
+ # At this point both x and y contain tensors in the `DistributedValues`
+ # structure.
+ x_values = distribution_strategy.unwrap(x)
+ y_values = distribution_strategy.unwrap(y)
+
+ # Validate that the shape and dtype of all the elements in x are the same.
+ validate_all_tensor_shapes(x, x_values)
+ validate_all_tensor_types(x, x_values)
+
+ # Similarly for y, we perform the same validation
+ validate_all_tensor_shapes(y, y_values)
+ validate_all_tensor_types(y, y_values)
+
+ # Return the unwrapped values to avoid calling `unwrap` a second time.
+ return x_values, y_values
+
+
+def validate_all_tensor_types(x, x_values):
+ x_dtype = x_values[0].dtype
+ for i in range(1, len(x_values)):
+ if x_dtype != x_values[i].dtype:
+ raise ValueError('Input tensor dtypes do not match for distributed tensor'
+ ' inputs {}'.format(x))
+
+
+def validate_all_tensor_shapes(x, x_values):
+ # Validate that all the elements in x have the same shape.
+ x_shape = x_values[0].get_shape().as_list()
+ for i in range(1, len(x_values)):
+ if x_shape != x_values[i].get_shape().as_list():
+ raise ValueError('Input tensor shapes do not match for distributed tensor'
+ ' inputs {}'.format(x))
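Editorial note: illustrative behavior of the two validators above (assuming the module path this patch adds; plain tensors stand in for unwrapped per-device values):

    import tensorflow as tf
    from tensorflow.python.keras.engine import distributed_training_utils as dtu

    x_values = [tf.zeros([4, 2], dtype=tf.float32),
                tf.zeros([4, 2], dtype=tf.int32)]   # dtype mismatch
    try:
      dtu.validate_all_tensor_types('x', x_values)
    except ValueError as e:
      print(e)  # Input tensor dtypes do not match for distributed tensor inputs x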
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 20a29dbf20..8f35794456 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -29,6 +29,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@@ -46,7 +47,6 @@ from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils
from tensorflow.python.training.checkpointable import util as checkpointable_utils
-from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -214,7 +214,7 @@ class Network(base_layer.Layer):
self._base_init(name=name)
self._compute_previous_mask = (
- 'mask' in tf_inspect.getargspec(self.call).args or
+ 'mask' in tf_inspect.getfullargspec(self.call).args or
hasattr(self, 'compute_mask'))
# A Network does not create weights of its own, thus it is already
# built.
@@ -270,23 +270,6 @@ class Network(base_layer.Layer):
input_tensors=self.inputs,
output_tensors=self.outputs)
- # Fill in the output mask cache.
- masks = []
- for x in self.inputs:
- mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access
- masks.append(mask)
- mask_cache_key = (generic_utils.object_list_uid(self.inputs) + '_' +
- generic_utils.object_list_uid(masks))
- masks = []
- for x in self.outputs:
- mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access
- masks.append(mask)
- if len(masks) == 1:
- mask = masks[0]
- else:
- mask = masks
- self._output_mask_cache[mask_cache_key] = mask
-
# Build self.input_names and self.output_names.
self.input_names = []
self.output_names = []
@@ -308,7 +291,7 @@ class Network(base_layer.Layer):
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
- call_argspec = tf_inspect.getargspec(self.call)
+ call_argspec = tf_inspect.getfullargspec(self.call)
if 'training' in call_argspec.args:
self._expects_training_arg = True
else:
@@ -512,13 +495,9 @@ class Network(base_layer.Layer):
masks = [None for _ in range(len(inputs))]
else:
masks = generic_utils.to_list(mask)
- cache_key = (generic_utils.object_list_uid(inputs)
- + '_' + generic_utils.object_list_uid(masks))
- if cache_key in self._output_mask_cache:
- return self._output_mask_cache[cache_key]
- else:
- _, output_masks = self._run_internal_graph(inputs, mask=masks)
- return output_masks
+
+ _, output_masks = self._run_internal_graph(inputs, mask=masks)
+ return output_masks
@property
def layers(self):
@@ -735,6 +714,7 @@ class Network(base_layer.Layer):
return specs[0]
return specs
+ @base_layer.default
def build(self, input_shape):
"""Builds the model based on input shapes received.
@@ -773,35 +753,41 @@ class Network(base_layer.Layer):
'input type: {}'.format(type(input_shape)))
if input_shape and not self.inputs:
- if isinstance(input_shape, list):
- # List of input shapes
- x = [base_layer.generate_dummy_data_from_shape(shape)
- for shape in input_shape]
- else:
- x = base_layer.generate_dummy_data_from_shape(input_shape)
-
- kwargs = {}
- num_call_args = len(tf_inspect.getargspec(self.call).args)
- if self._expects_training_arg and num_call_args == 3:
- # Has call signature of call(self, input, training)
- kwargs['training'] = False
- elif num_call_args > 2:
- # Has invalid call signature of call(self, input, *args, **kwargs)
- raise ValueError('Currently, you cannot build your model if it has '
- 'positional or keyword arguments that are not '
- 'inputs to the model, but are required for its '
- '`call` method. Instead, in order to instantiate '
- 'and build your model, `call` your model on real '
- 'tensor data with all expected call arguments.')
+ # We create placeholders for the `None`s in the shape and build the model
+ # in a Graph. Since tf.Variable is compatible with both eager execution
+ # and graph building, the variables created after building the model in
+ # a Graph are still valid when executing eagerly.
+ with context.graph_mode():
+ graph = eager_function.CapturingGraph()
+ with graph.as_default():
+ if isinstance(input_shape, list):
+ x = [base_layer.generate_placeholders_from_shape(shape)
+ for shape in input_shape]
+ else:
+ x = base_layer.generate_placeholders_from_shape(input_shape)
- try:
- self.call(x, **kwargs)
- except (errors.InvalidArgumentError, TypeError):
- raise ValueError('You cannot build your model by calling `build` '
- 'if your layers do not support float type inputs. '
- 'Instead, in order to instantiate and build your '
- 'model, `call` your model on real tensor data (of '
- 'the correct dtype).')
+ kwargs = {}
+ num_call_args = len(tf_inspect.getfullargspec(self.call).args)
+ if self._expects_training_arg and num_call_args == 3:
+ # Has call signature of call(self, input, training)
+ kwargs['training'] = False
+ elif num_call_args > 2:
+ # Has invalid call signature of call(self, input, *args, **kwargs)
+ raise ValueError('Currently, you cannot build your model if it has '
+ 'positional or keyword arguments that are not '
+ 'inputs to the model, but are required for its '
+ '`call` method. Instead, in order to instantiate '
+ 'and build your model, `call` your model on real '
+ 'tensor data with all expected call arguments.')
+
+ try:
+ self.call(x, **kwargs)
+ except (errors.InvalidArgumentError, TypeError):
+ raise ValueError('You cannot build your model by calling `build` '
+ 'if your layers do not support float type inputs. '
+ 'Instead, in order to instantiate and build your '
+ 'model, `call` your model on real tensor data (of '
+ 'the correct dtype).')
if self._layers:
self._track_layers(self._layers)
@@ -833,26 +819,26 @@ class Network(base_layer.Layer):
A tensor if there is a single output, or
a list of tensors if there are more than one outputs.
"""
- inputs = nest.flatten(inputs)
+ inputs = generic_utils.to_list(inputs)
if mask is None:
masks = [None for _ in range(len(inputs))]
else:
- masks = nest.flatten(mask)
-
- if not context.executing_eagerly():
- # Try to retrieve cached outputs if the layer has already been called
- # on these exact inputs.
- cache_key = (generic_utils.object_list_uid(inputs)
- + '_' + generic_utils.object_list_uid(masks))
- if cache_key in self._output_tensor_cache:
- # Cache hit.
- return self._output_tensor_cache[cache_key]
- # Actually apply the network graph to the new inputs.
+ masks = generic_utils.to_list(mask)
outputs, _ = self._run_internal_graph(inputs,
training=training,
mask=masks)
return outputs
+ def _call_and_compute_mask(self, inputs, training=None, mask=None):
+ inputs = generic_utils.to_list(inputs)
+ if mask is None:
+ masks = [None for _ in range(len(inputs))]
+ else:
+ masks = generic_utils.to_list(mask)
+ return self._run_internal_graph(inputs,
+ training=training,
+ mask=masks)
+
def compute_output_shape(self, input_shape):
if not self._is_graph_network:
if context.executing_eagerly():
@@ -878,9 +864,10 @@ class Network(base_layer.Layer):
' tensor inputs.')
cache_key = generic_utils.object_list_uid(input_shapes)
- if cache_key not in self._output_shape_cache:
- # Cache miss. We have to run the network graph manually (recursive calls
- # to `compute_output_shape`).
+ if cache_key in self._output_shape_cache:
+ # Cache hit.
+ output_shapes = self._output_shape_cache[cache_key]
+ else:
layers_to_output_shapes = {}
for i in range(len(input_shapes)):
layer = self._input_layers[i]
@@ -942,9 +929,6 @@ class Network(base_layer.Layer):
output_shapes.append(layers_to_output_shapes[shape_key])
# Store in cache.
self._output_shape_cache[cache_key] = output_shapes
- else:
- # Cache hit.
- output_shapes = self._output_shape_cache[cache_key]
if isinstance(output_shapes, list):
if len(output_shapes) == 1:
@@ -984,8 +968,6 @@ class Network(base_layer.Layer):
# Dictionary mapping reference tensors to tuples
# (computed tensor, compute mask)
# we assume a 1:1 mapping from tensor to mask
- # TODO(fchollet): raise exception when a `.compute_mask()` call
- # does not return a list the same size as `call`
tensor_map = {}
for x, y, mask in zip(self.inputs, inputs, masks):
tensor_map[str(id(x))] = (y, mask)
@@ -1014,53 +996,67 @@ class Network(base_layer.Layer):
kwargs = node.arguments
else:
kwargs = {}
+ # Ensure `training` arg propagation if applicable.
+ if 'training' in tf_inspect.getfullargspec(layer.call).args:
+ kwargs.setdefault('training', training)
+
if len(computed_data) == 1:
computed_tensor, computed_mask = computed_data[0]
# Ensure mask propagation if applicable.
- if 'mask' in tf_inspect.getargspec(layer.call).args:
+ if 'mask' in tf_inspect.getfullargspec(layer.call).args:
kwargs.setdefault('mask', computed_mask)
- if 'training' in tf_inspect.getargspec(layer.call).args:
- kwargs.setdefault('training', training)
-
- output_tensors = nest.flatten(
- layer.call(computed_tensor, **kwargs))
- if hasattr(layer, 'compute_mask'):
- output_masks = layer.compute_mask(computed_tensor,
- computed_mask)
- if output_masks is None:
- output_masks = [None for _ in output_tensors]
- else:
- output_masks = nest.flatten(output_masks)
+
+ # Compute outputs and masks.
+ if isinstance(layer, Network) and layer._is_graph_network:
+ output_tensors, output_masks = layer._call_and_compute_mask(
+ computed_tensor, **kwargs)
else:
- output_masks = [None for _ in output_tensors]
+ output_tensors = layer.call(computed_tensor, **kwargs)
+ if hasattr(layer, 'compute_mask'):
+ output_masks = layer.compute_mask(computed_tensor,
+ computed_mask)
+ else:
+ output_masks = [None for _ in output_tensors]
computed_tensors = [computed_tensor]
+
else:
computed_tensors = [x[0] for x in computed_data]
computed_masks = [x[1] for x in computed_data]
- if 'mask' in tf_inspect.getargspec(layer.call).args:
+ # Ensure mask propagation if applicable.
+ if 'mask' in tf_inspect.getfullargspec(layer.call).args:
kwargs.setdefault('mask', computed_masks)
- if 'training' in tf_inspect.getargspec(layer.call).args:
- kwargs.setdefault('training', training)
- output_tensors = nest.flatten(
- layer.call(computed_tensors, **kwargs))
-
- if hasattr(layer, 'compute_mask'):
- output_masks = layer.compute_mask(computed_tensors,
- computed_masks)
- if output_masks is None:
- output_masks = [None for _ in output_tensors]
- else:
- output_masks = nest.flatten(output_masks)
+ # Compute outputs and masks.
+ if isinstance(layer, Network) and layer._is_graph_network:
+ output_tensors, output_masks = layer._call_and_compute_mask(
+ computed_tensors, **kwargs)
else:
- output_masks = [None for _ in output_tensors]
+ output_tensors = layer.call(computed_tensors, **kwargs)
+ if hasattr(layer, 'compute_mask'):
+ output_masks = layer.compute_mask(computed_tensors,
+ computed_masks)
+ else:
+ output_masks = [None for _ in output_tensors]
+
+ output_tensors = generic_utils.to_list(output_tensors)
+ if output_masks is None:
+ output_masks = [None for _ in output_tensors]
+ else:
+ output_masks = generic_utils.to_list(output_masks)
if not context.executing_eagerly():
+ # Set mask metadata.
+ for x, m in zip(output_tensors, output_masks):
+ try:
+ x._keras_mask = m
+ except AttributeError:
+ pass
+
+ # Apply activity regularizer if any.
if layer.activity_regularizer is not None:
regularization_losses = [
layer.activity_regularizer(x) for x in output_tensors
]
- # Apply activity regularizer if any:
layer.add_loss(regularization_losses, computed_tensors)
# Update tensor_map.
@@ -1085,18 +1081,10 @@ class Network(base_layer.Layer):
if output_masks is not None:
output_masks = output_masks[0]
- if not context.executing_eagerly():
- # Update cache;
- # keys are based on ids on input tensors and inputs masks.
- cache_key = (generic_utils.object_list_uid(inputs)
- + '_' + generic_utils.object_list_uid(masks))
- self._output_tensor_cache[cache_key] = output_tensors
- self._output_mask_cache[cache_key] = output_masks
-
- if output_shapes is not None:
- input_shapes = [backend.int_shape(x) for x in inputs]
- cache_key = generic_utils.object_list_uid(input_shapes)
- self._output_shape_cache[cache_key] = output_shapes
+ if output_shapes is not None:
+ input_shapes = [backend.int_shape(x) for x in inputs]
+ cache_key = generic_utils.object_list_uid(input_shapes)
+ self._output_shape_cache[cache_key] = output_shapes
return output_tensors, output_masks
@@ -1439,6 +1427,16 @@ class Network(base_layer.Layer):
session = None
else:
session = backend.get_session()
+ optimizer = getattr(self, 'optimizer', None)
+ if (optimizer
+ and not isinstance(optimizer, checkpointable.CheckpointableBase)):
+ logging.warning(
+ ('This model was compiled with a Keras optimizer (%s) but is being '
+ 'saved in TensorFlow format with `save_weights`. The model\'s '
+ 'weights will be saved, but unlike with TensorFlow optimizers in '
+ 'the TensorFlow format the optimizer\'s state will not be '
+ 'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
+ % (optimizer,))
self._checkpointable_saver.save(filepath, session=session)
def load_weights(self, filepath, by_name=False):
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index e029e614e0..f2f8a27b76 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training as training_module
try:
@@ -663,6 +664,22 @@ class SubclassedModel(training.Model):
class TestWeightSavingAndLoadingTFFormat(test.TestCase):
+ def test_keras_optimizer_warning(self):
+ graph = ops.Graph()
+ with graph.as_default(), self.test_session(graph):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='adam', metrics=['acc'])
+ model._make_train_function()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ model.save_weights(prefix)
+ self.assertRegexpMatches(
+ str(mock_log.call_args),
+ 'Keras optimizer')
+
@test_util.run_in_graph_and_eager_modes
def test_tensorflow_format_overwrite(self):
with self.test_session() as session:
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 34f74db6ef..079c8dae71 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -24,6 +24,7 @@ from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer as input_layer_lib
@@ -1068,6 +1069,101 @@ class DefaultShapeInferenceBehaviorTest(test.TestCase):
outputs = LayerWithAdditionalArg()(inputs, some_arg=0)
_ = keras.Model(inputs, outputs)
+ @test_util.run_in_graph_and_eager_modes()
+ def testNoneInShape(self):
+
+ class Model(keras.Model):
+
+ def __init__(self):
+ super(Model, self).__init__()
+ self.conv1 = keras.layers.Conv2D(8, 3)
+ self.pool = keras.layers.GlobalAveragePooling2D()
+ self.fc = keras.layers.Dense(3)
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.pool(x)
+ x = self.fc(x)
+ return x
+
+ model = Model()
+ model.build(tensor_shape.TensorShape((None, None, None, 1)))
+ self.assertTrue(model.built, 'Model should be built')
+ self.assertTrue(model.weights,
+ 'Model should have its weights created as it '
+ 'has been built')
+ sample_input = array_ops.ones((1, 10, 10, 1))
+ output = model(sample_input)
+ self.assertEqual(output.shape, (1, 3))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testNoneInShapeWithCompoundModel(self):
+
+ class BasicBlock(keras.Model):
+
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+ self.conv1 = keras.layers.Conv2D(8, 3)
+ self.pool = keras.layers.GlobalAveragePooling2D()
+ self.dense = keras.layers.Dense(3)
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.pool(x)
+ x = self.dense(x)
+ return x
+
+ class CompoundModel(keras.Model):
+
+ def __init__(self):
+ super(CompoundModel, self).__init__()
+ self.block = BasicBlock()
+
+ def call(self, x):
+ x = self.block(x) # pylint: disable=not-callable
+ return x
+
+ model = CompoundModel()
+ model.build(tensor_shape.TensorShape((None, None, None, 1)))
+ self.assertTrue(model.built, 'Model should be built')
+ self.assertTrue(model.weights,
+ 'Model should have its weights created as it '
+ 'has been built')
+ sample_input = array_ops.ones((1, 10, 10, 1))
+ output = model(sample_input) # pylint: disable=not-callable
+ self.assertEqual(output.shape, (1, 3))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testNoneInShapeWithFunctionalAPI(self):
+
+ class BasicBlock(keras.Model):
+ # Inheriting from keras.layers.Layer since we are calling this layer
+ # inside a model created using the functional API.
+
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+ self.conv1 = keras.layers.Conv2D(8, 3)
+
+ def call(self, x):
+ x = self.conv1(x)
+ return x
+
+ input_layer = keras.layers.Input(shape=(None, None, 1))
+ x = BasicBlock()(input_layer)
+ x = keras.layers.GlobalAveragePooling2D()(x)
+ output_layer = keras.layers.Dense(3)(x)
+
+ model = keras.Model(inputs=input_layer, outputs=output_layer)
+
+ model.build(tensor_shape.TensorShape((None, None, None, 1)))
+ self.assertTrue(model.built, 'Model should be built')
+ self.assertTrue(model.weights,
+ 'Model should have its weights created as it '
+ 'has been built')
+ sample_input = array_ops.ones((1, 10, 10, 1))
+ output = model(sample_input)
+ self.assertEqual(output.shape, (1, 3))
+
class GraphUtilsTest(test.TestCase):
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index c8e5e497ae..3db0b4c8ad 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -31,7 +31,9 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.engine import training_arrays
+from tensorflow.python.keras.engine import training_distributed
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
@@ -112,6 +114,9 @@ class Model(Network):
self._iterator_get_next = weakref.WeakKeyDictionary()
# Create a cache for dataset - uninitialized iterators
self._dataset_iterator_cache = weakref.WeakKeyDictionary()
+ # initializing _distribution_strategy here since it is possible to call
+ # predict on a model without compiling it.
+ self._distribution_strategy = None
def _set_sample_weight_attributes(self, sample_weight_mode,
skip_target_weighing_indices):
@@ -140,6 +145,7 @@ class Model(Network):
sample_weight_mode=None,
weighted_metrics=None,
target_tensors=None,
+ distribute=None,
**kwargs):
"""Configures the model for training.
@@ -183,12 +189,32 @@ class Model(Network):
can specify them via the `target_tensors` argument. It can be
a single tensor (for a single-output model), a list of tensors,
or a dict mapping output names to target tensors.
+ distribute: The DistributionStrategy instance that we want to use to
+ distribute the training of the model.
**kwargs: These arguments are passed to `tf.Session.run`.
Raises:
ValueError: In case of invalid arguments for
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
"""
+ # Validate that arguments passed by the user to `compile` are supported by
+ # DistributionStrategy.
+ if distribute and not isinstance(
+ optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
+ raise ValueError('Only TF native optimizers are supported with '
+ 'DistributionStrategy.')
+ if distribute and context.executing_eagerly():
+ raise ValueError('DistributionStrategy is not supported in Eager mode.')
+ if distribute and sample_weight_mode:
+ raise ValueError('sample_weight_mode is not supported with '
+ 'DistributionStrategy.')
+ if distribute and weighted_metrics:
+ raise ValueError('weighted_metrics is not supported with '
+ 'DistributionStrategy.')
+ if distribute and target_tensors:
+ raise ValueError('target_tensors is not supported with '
+ 'DistributionStrategy.')
+
loss = loss or {}
if context.executing_eagerly() and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
@@ -211,6 +237,17 @@ class Model(Network):
raise ValueError('target_tensors is not supported in Eager mode.')
self.target_tensors = target_tensors
+ # Set DistributionStrategy specific parameters.
+ self._distribution_strategy = distribute
+ if self._distribution_strategy is not None:
+ self._grouped_model = self._compile_distributed_model(
+ self._distribution_strategy)
+ with self._distribution_strategy.scope():
+ first_replicated_model = self._distribution_strategy.unwrap(
+ self._grouped_model)[0]
+ # We initialize the callback model with the first replicated model.
+ self._replicated_model = DistributedCallbackModel(first_replicated_model)
+ self._replicated_model.set_original_model(self)
if not self.built:
# Model is not compilable because it does not know its number of inputs
# and outputs, nor their shapes and names. We will compile after the first
@@ -261,9 +298,7 @@ class Model(Network):
# Prepare output masks.
if not context.executing_eagerly():
- masks = self.compute_mask(self.inputs, mask=None)
- if masks is None:
- masks = [None for _ in self.outputs]
+ masks = [getattr(x, '_keras_mask', None) for x in self.outputs]
if not isinstance(masks, list):
masks = [masks]
@@ -482,6 +517,19 @@ class Model(Network):
trainable_weights = self.trainable_weights
self._collected_trainable_weights = trainable_weights
+ def _compile_distributed_model(self, distribution_strategy):
+ # TODO(anjalisridhar): Can we move the clone_and_build_model to outside the
+ # model?
+ def _clone_model_per_tower(model):
+ new_model = training_distributed.clone_and_build_model(model)
+ return new_model
+
+ with distribution_strategy.scope():
+ # Create a copy of this model on each of the devices.
+ grouped_models = distribution_strategy.call_for_each_tower(
+ _clone_model_per_tower, self)
+ return grouped_models
+
def _check_trainable_weights_consistency(self):
"""Check trainable weights count consistency.
@@ -572,6 +620,103 @@ class Model(Network):
self._iterator_get_next[iterator] = get_next_op
return get_next_op
+ def _distribution_standardize_user_data(self,
+ x,
+ y=None,
+ sample_weight=None,
+ class_weight=None,
+ batch_size=None,
+ check_steps=False,
+ steps_name='steps',
+ steps=None,
+ validation_split=0):
+ """Runs validation checks on input and target data passed by the user.
+
+ This is called when using DistributionStrategy to train, evaluate or serve
+ the model.
+
+ Args:
+ x: Input data. A `tf.data` dataset.
+ y: Since `x` is a dataset, `y` should not be specified
+ (since targets will be obtained from the iterator).
+ sample_weight: An optional sample-weight array passed by the user to
+ weight the importance of each sample in `x`.
+ class_weight: An optional class-weight array by the user to
+ weight the importance of samples in `x` based on the class they belong
+ to, as conveyed by `y`.
+ batch_size: Integer batch size. If provided, it is used to run additional
+ validation checks on stateful models.
+ check_steps: boolean, True if we want to check for validity of `steps` and
+ False otherwise.
+ steps_name: The public API's parameter name for `steps`.
+ steps: Integer or `None`. Total number of steps (batches of samples) to
+ execute.
+ validation_split: Float between 0 and 1.
+ Fraction of the training data to be used as validation data.
+
+ Returns:
+ A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
+ If the model's input and targets are symbolic, these lists are empty
+ (since the model takes no user-provided data, instead the data comes
+ from the symbolic inputs/targets).
+
+ Raises:
+ ValueError: In case of invalid user-provided data.
+ RuntimeError: If the model was never compiled.
+ """
+ if sample_weight is not None and sample_weight.all():
+ raise ValueError('sample_weight is currently not supported when using '
+ 'DistributionStrategy.')
+ if class_weight:
+ raise ValueError('class_weight is currently not supported when using '
+ 'DistributionStrategy.')
+
+ # TODO(anjalisridhar): Can we use the iterator and getnext op cache?
+ # We require users to pass Datasets since we distribute the dataset across
+ # multiple devices.
+ if not isinstance(x, dataset_ops.Dataset):
+ raise ValueError('When using DistributionStrategy you must specify a '
+ 'Dataset object instead of a %s.' % type(x))
+ # TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a
+ # function which returns a Dataset. Currently distribute_dataset() only
+ # accepts a function that returns a Dataset. Once we add support for being
+ # able to clone a Dataset on multiple workers we can remove this lambda.
+ result = self._distribution_strategy.distribute_dataset(lambda: x)
+ iterator = result.make_initializable_iterator()
+ K.get_session().run(iterator.initializer)
+ # Validates `steps` argument based on x's type.
+ if check_steps:
+ if steps is None:
+ raise ValueError('When using a Dataset instance as input to a model, '
+ 'you should specify the `{steps_name}` argument.'
+ .format(steps_name=steps_name))
+
+ training_utils.validate_iterator_input(x, y, sample_weight,
+ validation_split)
+ # x and y may be PerDevice objects with an input and output tensor
+ # corresponding to each device. For example, x could be
+ # PerDevice:{device: get_next tensor,...}.
+ next_element = iterator.get_next()
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError('Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s' % next_element)
+ x, y = next_element
+ # Validate that all the elements in x and y are of the same type and shape.
+ # We can then pass the first element of x and y to `_standardize_weights`
+ # below and be confident of the output. We need to reopen the scope since
+ # we unwrap values when we validate x and y.
+ with self._distribution_strategy.scope():
+ x_values, y_values = distributed_training_utils.\
+ validate_distributed_dataset_inputs(self._distribution_strategy, x, y)
+
+ _, _, sample_weights = self._standardize_weights(x_values[0],
+ y_values[0],
+ sample_weight,
+ class_weight,
+ batch_size)
+ return x, y, sample_weights
+
def _standardize_user_data(self,
x,
y=None,
@@ -634,6 +779,18 @@ class Model(Network):
ValueError: In case of invalid user-provided data.
RuntimeError: If the model was never compiled.
"""
+ if self._distribution_strategy:
+ return self._distribution_standardize_user_data(
+ x,
+ y,
+ sample_weight=sample_weight,
+ class_weight=class_weight,
+ batch_size=batch_size,
+ check_steps=check_steps,
+ steps_name=steps_name,
+ steps=steps,
+ validation_split=validation_split)
+
if isinstance(x, dataset_ops.Dataset):
if context.executing_eagerly():
x = x.make_one_shot_iterator()
@@ -682,7 +839,12 @@ class Model(Network):
raise ValueError('Please provide data as a list or tuple of 2 elements '
' - input and target pair. Received %s' % next_element)
x, y = next_element
+ x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
+ class_weight, batch_size)
+ return x, y, sample_weights
+ def _standardize_weights(self, x, y, sample_weight=None, class_weight=None,
+ batch_size=None):
# First, we build/compile the model on the fly if necessary.
all_inputs = []
is_build_called = False
@@ -844,11 +1006,12 @@ class Model(Network):
feed_sample_weight_modes)
]
# Check that all arrays have the same length.
- training_utils.check_array_lengths(x, y, sample_weights)
- if self._is_graph_network and not context.executing_eagerly():
- # Additional checks to avoid users mistakenly using improper loss fns.
- training_utils.check_loss_and_target_compatibility(
- y, self._feed_loss_fns, feed_output_shapes)
+ if not self._distribution_strategy:
+ training_utils.check_array_lengths(x, y, sample_weights)
+ if self._is_graph_network and not context.executing_eagerly():
+ # Additional checks to avoid users mistakenly using improper loss fns.
+ training_utils.check_loss_and_target_compatibility(
+ y, self._feed_loss_fns, feed_output_shapes)
else:
y = []
sample_weights = []
@@ -1186,6 +1349,9 @@ class Model(Network):
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
# Validate and standardize user data.
+ if self._distribution_strategy:
+ distributed_training_utils.validate_callbacks(callbacks)
+
x, y, sample_weights = self._standardize_user_data(
x,
y,
@@ -1266,6 +1432,17 @@ class Model(Network):
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps)
+ elif self._distribution_strategy:
+ return training_distributed.fit_loop(
+ self, x, y,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ val_inputs=val_x,
+ val_targets=val_y,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps)
else:
return training_arrays.fit_loop(
self, x, y,
@@ -1358,12 +1535,29 @@ class Model(Network):
if context.executing_eagerly():
return training_eager.test_loop(
- self, inputs=x, targets=y, sample_weights=sample_weights,
- batch_size=batch_size, verbose=verbose, steps=steps)
+ self,
+ inputs=x,
+ targets=y,
+ sample_weights=sample_weights,
+ batch_size=batch_size,
+ verbose=verbose,
+ steps=steps)
+ elif self._distribution_strategy:
+ return training_distributed.test_loop(
+ self,
+ inputs=x,
+ targets=y,
+ verbose=verbose,
+ steps=steps)
else:
return training_arrays.test_loop(
- self, inputs=x, targets=y, sample_weights=sample_weights,
- batch_size=batch_size, verbose=verbose, steps=steps)
+ self,
+ inputs=x,
+ targets=y,
+ sample_weights=sample_weights,
+ batch_size=batch_size,
+ verbose=verbose,
+ steps=steps)
def predict(self, x, batch_size=None, verbose=0, steps=None):
"""Generates output predictions for the input samples.
@@ -1408,6 +1602,9 @@ class Model(Network):
if context.executing_eagerly():
return training_eager.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
+ elif self._distribution_strategy:
+ return training_distributed.predict_loop(
+ self, x, verbose=verbose, steps=steps)
else:
return training_arrays.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
@@ -1455,6 +1652,9 @@ class Model(Network):
Raises:
ValueError: In case of invalid user-provided arguments.
"""
+ if self._distribution_strategy:
+ raise ValueError('`train_on_batch` is not supported for models '
+ 'compiled with DistributionStrategy.')
# Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight, class_weight=class_weight)
@@ -1511,6 +1711,9 @@ class Model(Network):
Raises:
ValueError: In case of invalid user-provided arguments.
"""
+ if self._distribution_strategy:
+ raise ValueError('`test_on_batch` is not supported for models '
+ 'compiled with DistributionStrategy.')
# Validate and standardize user data.
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight)
@@ -1548,6 +1751,9 @@ class Model(Network):
ValueError: In case of mismatch between given number of inputs and
expectations of the model.
"""
+ if self._distribution_strategy:
+ raise ValueError('`predict_on_batch` is not supported for models '
+ 'compiled with DistributionStrategy.')
# Validate and standardize user data.
inputs, _, _ = self._standardize_user_data(x)
if context.executing_eagerly():
@@ -1811,3 +2017,45 @@ class Model(Network):
workers=workers,
use_multiprocessing=use_multiprocessing,
verbose=verbose)
+
+
+class DistributedCallbackModel(Model):
+ """Model that is used for callbacks with DistributionStrategy."""
+
+ def __init__(self, model):
+ super(DistributedCallbackModel, self).__init__()
+ # TODO(anjalisridhar): Right now the only attributes set are the layer and
+ # weights. We may need to set additional attributes as needed since we have
+ # not called compile on this model.
+
+ def set_original_model(self, orig_model):
+ self._original_model = orig_model
+
+ def save_weights(self, filepath, overwrite=True, save_format=None):
+ self._replicated_model.save_weights(filepath, overwrite=overwrite,
+ save_format=save_format)
+
+ def save(self, filepath, overwrite=True, include_optimizer=True):
+ # save weights from the distributed model to the original model
+ distributed_model_weights = self.get_weights()
+ self._original_model.set_weights(distributed_model_weights)
+ # TODO(anjalisridhar): Do we need to save the original model here?
+ # Saving the first replicated model works as well.
+ self._original_model.save(filepath, overwrite=True, include_optimizer=False)
+
+ def load_weights(self, filepath, by_name=False):
+ self._original_model.load_weights(filepath, by_name=by_name)
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = self._original_model.get_weights()
+ distributed_training_utils.set_weights(
+ self._original_model._distribution_strategy, self, # pylint: disable=protected-access
+ orig_model_weights)
+
+ def __getattr__(self, item):
+ # Whitelisted attributes of the model that can be accessed by the user
+ # during a callback.
+ if item not in ['_setattr_tracking']:
+ logging.warning('You are accessing attribute ' + item + ' of the '
+ 'DistributedCallbackModel that may not have been set '
+ 'correctly.')
+
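Note on the class above: DistributedCallbackModel exists so that callbacks which save or load weights during training can bridge between the replicated model used for execution and the user's original model. A minimal sketch of that weight round-trip, assuming hypothetical `original_model` and `replica_model` Keras models (the real code reaches the replica through `model._grouped_model` and the DistributionStrategy):

def checkpoint_through_original(original_model, replica_model, filepath):
    # Pull the trained weights out of the replicated model...
    replica_weights = replica_model.get_weights()
    # ...copy them into the user-facing original model...
    original_model.set_weights(replica_weights)
    # ...and save from the original, mirroring DistributedCallbackModel.save().
    original_model.save(filepath, overwrite=True, include_optimizer=False)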
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index 6572e2c344..d24f4b64b9 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -200,7 +200,9 @@ def fit_loop(model,
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your dataset '
'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
+ 'batches (in this case, %d batches). You may need to '
+ 'use the repeat() function when building your '
+ 'dataset.' %
steps_per_epoch * epochs)
break
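The expanded warning above points users at `repeat()`: the failure mode is a finite dataset that is exhausted before `steps_per_epoch * epochs` batches have been drawn. A small illustrative sketch (shapes and step counts are made up, and the model is assumed to be already compiled):

import numpy as np
import tensorflow as tf

x = np.random.random((100, 4)).astype(np.float32)
y = np.random.random((100, 1)).astype(np.float32)

# batch(10) yields only 10 batches per pass over the data; repeat() lets the
# fit loop keep drawing batches for steps_per_epoch * epochs steps without
# raising OutOfRangeError.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(10).repeat()
# model.fit(dataset, steps_per_epoch=10, epochs=5)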
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
new file mode 100644
index 0000000000..75e466d593
--- /dev/null
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -0,0 +1,455 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Part of the Keras training engine related to distributed training.
+"""
+# pylint: disable=protected-access
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import copy
+import numpy as np
+from tensorflow.python.framework import errors
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import callbacks as cbks
+from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.platform import tf_logging as logging
+
+
+def fit_loop(
+ model,
+ inputs,
+ targets,
+ epochs=100,
+ verbose=1,
+ callbacks=None,
+ val_inputs=None,
+ val_targets=None,
+ callback_metrics=None,
+ initial_epoch=0,
+ steps_per_epoch=None,
+ validation_steps=None):
+ """Fit function for training a model with DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ epochs: Number of times to iterate over the data.
+ verbose: Verbosity mode, 0, 1 or 2.
+ callbacks: List of callbacks to be called during training.
+ val_inputs: List of input arrays.
+ val_targets: List of target arrays.
+ callback_metrics: List of strings, the display names of the metrics
+ passed to the callbacks. They should be the concatenation of the
+ display names of the outputs of `f` and the display names of the
+ outputs of `f_val`.
+ initial_epoch: Epoch at which to start training
+ (useful for resuming a previous training run)
+ steps_per_epoch: Total number of steps (batches of samples)
+ before declaring one epoch finished and starting the
+ next epoch. Ignored with the default value of `None`.
+ validation_steps: Number of steps to run validation for
+ (only if doing validation from data tensors).
+ Ignored with the default value of `None`.
+
+ Returns:
+ `History` object.
+
+ Raises:
+ ValueError: in case of invalid arguments.
+ """
+ current_strategy = model._distribution_strategy
+ def _per_device_train_function(model):
+ model._make_train_function()
+ return (model.train_function.inputs,
+ model.train_function.outputs,
+ model.train_function.updates_op,
+ model.train_function.session_kwargs)
+
+ with current_strategy.scope():
+ # Create train ops on each of the devices when we call
+ # `_per_device_train_function`.
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_train_function, model._grouped_model)
+ # Unwrap all the per device values returned from `call_for_each_tower`.
+ # Unwrapping per device values gives you a list of values that can be
+ # used to construct a new train function that is composed of update ops on
+ # all the devices over which the model is distributed.
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs,
+ grouped_updates, grouped_session_args, with_loss_tensor=True)
+
+ # Dataset inputs and targets are also per device values that need to be
+ # unwrapped.
+ dataset_inputs = distributed_training_utils.flatten_perdevice_values(
+ current_strategy, inputs)
+ dataset_targets = distributed_training_utils.flatten_perdevice_values(
+ current_strategy, targets)
+
+ # Create a train function that is composed of all the parameters above.
+ distributed_train_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
+
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [1]
+ else:
+ ins = dataset_inputs + dataset_targets
+
+ do_validation = False
+ if validation_steps:
+ do_validation = True
+ if steps_per_epoch is None:
+ raise ValueError('Can only use `validation_steps` '
+ 'when doing step-wise '
+ 'training, i.e. `steps_per_epoch` '
+ 'must be set.')
+ out_labels = model.metrics_names
+ if do_validation:
+ callback_metrics = copy.copy(out_labels) + [
+ 'val_' + n for n in out_labels
+ ]
+ else:
+ callback_metrics = copy.copy(out_labels)
+
+ model.history = cbks.History()
+ all_callbacks = [cbks.BaseLogger(
+ stateful_metrics=model.stateful_metric_names)]
+ if verbose:
+ # We assume that `steps_per_epoch` is always set since we have to use
+ # Datasets.
+ count_mode = 'steps'
+
+ all_callbacks.append(
+ cbks.ProgbarLogger(
+ count_mode, stateful_metrics=model.stateful_metric_names))
+ all_callbacks += (callbacks or []) + [model.history]
+ callbacks = cbks.CallbackList(all_callbacks)
+ out_labels = out_labels or []
+
+ # We set the callback model to an instance of the `DistributedCallbackModel`
+ # that we create in the `compile` call. The `DistributedCallbackModel` is
+ # initialized with the first replicated model. We need to set the callback
+ # model to a `DistributedCallbackModel` to allow us to override saving and
+ # loading weights when we checkpoint the model during training.
+ callback_model = model._replicated_model
+
+ callbacks.set_model(callback_model)
+
+ callbacks.set_params({
+ 'epochs': epochs,
+ 'steps': steps_per_epoch,
+ 'samples': None,
+ 'verbose': verbose,
+ 'do_validation': do_validation,
+ 'metrics': callback_metrics or [],
+ })
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
+
+ out_labels = out_labels or []
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ for epoch in range(initial_epoch, epochs):
+ callbacks.on_epoch_begin(epoch)
+ if steps_per_epoch is not None:
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {}
+ batch_logs['batch'] = step_index
+ batch_logs['size'] = 1
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' %
+ (steps_per_epoch * epochs))
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ # TODO(anjalisridhar): Temporary workaround for aggregating metrics
+ # across towers. Replace with the new metrics module eventually.
+ merged_output = []
+ # The first output is the total loss.
+ merged_output.append(outs[0])
+ current_index = 1
+ num_devices = len(current_strategy._devices)
+ # Each label in `out_labels` corresponds to one set of metrics. The
+ # number of metric values corresponds to the number of devices. We
+ # currently take the mean of the values.
+ for _ in out_labels[1:]:
+ m = np.mean(outs[current_index:current_index + num_devices])
+ merged_output.append(m)
+ current_index += num_devices
+
+ for l, o in zip(out_labels, merged_output):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callback_model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_inputs,
+ val_targets,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callback_model.stop_training:
+ break
+ callbacks.on_train_end()
+
+ # Copy the weights back from the replicated model to the original model.
+ with current_strategy.scope():
+ updated_weights = current_strategy.unwrap(
+ model._grouped_model)[0].get_weights()
+ model.set_weights(updated_weights)
+ return model.history
+
+
+def test_loop(model, inputs, targets, verbose=0, steps=None):
+ """Evaluate function for validating a model that uses DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ verbose: verbosity mode.
+ steps: Total number of steps (batches of samples)
+ before declaring predictions finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+ """
+ current_strategy = model._distribution_strategy
+ def _per_device_test_function(model):
+ model._make_test_function()
+ return (model.test_function.inputs,
+ model.test_function.outputs,
+ model.test_function.updates_op,
+ model.test_function.session_kwargs)
+
+ with current_strategy.scope():
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_test_function, model._grouped_model)
+
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args, with_loss_tensor=True)
+
+ dataset_inputs = distributed_training_utils.flatten_perdevice_values(
+ current_strategy, inputs)
+ dataset_targets = distributed_training_utils.flatten_perdevice_values(
+ current_strategy, targets)
+
+ distributed_test_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
+
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [0]
+ else:
+ ins = dataset_inputs + dataset_targets
+
+ if hasattr(model, 'metrics'):
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names
+ ]
+ else:
+ stateful_metric_indices = []
+
+ outs = []
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ if steps is not None:
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ for _ in batch_outs:
+ outs.append(0.)
+ for i, batch_out in enumerate(batch_outs):
+ if i in stateful_metric_indices:
+ outs[i] = batch_out
+ else:
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose == 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ if i not in stateful_metric_indices:
+ outs[i] /= steps
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+def predict_loop(model, inputs, verbose=0, steps=None):
+ """Predict loop for a model that uses DistributionStrategy.
+
+ Arguments:
+ model: Keras Model instance.
+ inputs: List of tensors to be fed to the model.
+ verbose: Verbosity mode.
+ steps: Total number of steps (batches of samples)
+ before declaring the prediction loop finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions
+ (if the model has multiple outputs).
+ """
+ current_strategy = model._distribution_strategy
+ def _per_device_predict_function(model):
+ model._make_predict_function()
+ return (model.predict_function.inputs,
+ model.predict_function.outputs,
+ model.predict_function.updates_op,
+ model.predict_function.session_kwargs)
+
+ with current_strategy.scope():
+ (grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args) = current_strategy.call_for_each_tower(
+ _per_device_predict_function, model._grouped_model)
+
+ (all_inputs, all_outputs, all_updates,
+ all_session_args) = distributed_training_utils.unwrap_values(
+ current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+ grouped_session_args)
+
+ dataset_inputs = distributed_training_utils.flatten_perdevice_values(
+ current_strategy, inputs)
+
+ distributed_predict_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
+
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + [0]
+ else:
+ ins = dataset_inputs
+
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
+ with current_strategy.scope():
+ distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+ distributed_training_utils.set_weights(
+ current_strategy, distributed_model, orig_model_weights)
+
+ if steps is not None:
+ # Since we do not know how many samples we will see, we cannot pre-allocate
+ # the returned Numpy arrays. Instead, we store one array per batch seen
+ # and concatenate them upon returning.
+ unconcatenated_outs = []
+ for step in range(steps):
+ batch_outs = distributed_predict_function(ins)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step == 0:
+ for _ in batch_outs:
+ unconcatenated_outs.append([])
+ for i, batch_out in enumerate(batch_outs):
+ unconcatenated_outs[i].append(batch_out)
+ if verbose == 1:
+ progbar.update(step + 1)
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
+
+
+def clone_and_build_model(model):
+ """Clone and build the given keras_model."""
+ # We need to do the import here, at function scope, to avoid a circular
+ # dependency error.
+ from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top
+ cloned_model = models.clone_model(model, input_tensors=None)
+
+ # Compile and build model.
+ if isinstance(model.optimizer, optimizers.TFOptimizer):
+ optimizer = model.optimizer
+ else:
+ optimizer_config = model.optimizer.get_config()
+ optimizer = model.optimizer.__class__.from_config(optimizer_config)
+
+ cloned_model.compile(
+ optimizer,
+ model.loss,
+ metrics=model.metrics,
+ loss_weights=model.loss_weights,
+ sample_weight_mode=model.sample_weight_mode,
+ weighted_metrics=model.weighted_metrics)
+ return cloned_model
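The three loops in this new file follow one recipe: build the per-device Keras function inside the strategy scope via `call_for_each_tower`, unwrap the per-device values, and stitch them into a single `K.Function`. A condensed sketch of that shared pattern, using the module's existing imports; `_build_distributed_function` is a hypothetical helper, not part of the actual file:

def _build_distributed_function(model, make_fn_name, fn_attr, name):
    strategy = model._distribution_strategy

    def _per_device(m):
        getattr(m, make_fn_name)()   # e.g. m._make_train_function()
        fn = getattr(m, fn_attr)     # e.g. m.train_function
        return fn.inputs, fn.outputs, fn.updates_op, fn.session_kwargs

    with strategy.scope():
        grouped = strategy.call_for_each_tower(_per_device, model._grouped_model)
        inputs, outputs, updates, session_args = (
            distributed_training_utils.unwrap_values(strategy, *grouped))
        return K.Function(inputs, outputs, updates=updates, name=name,
                          **session_args)

# e.g. _build_distributed_function(model, '_make_test_function',
#                                  'test_function', 'distributed_test_function')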
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index c7fcf34981..774d2e44f3 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -92,21 +92,23 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
applies masking and sample weighting to the loss value.
"""
total_loss = 0
+ kwargs = {}
+ if model._expects_training_arg:
+ kwargs['training'] = training
if len(inputs) == 1:
- if model._expects_training_arg:
- outs = model.call(inputs[0], training=training)
- else:
- outs = model.call(inputs[0])
+ inputs = inputs[0]
+
+ if model._is_graph_network:
+ outs, masks = model._call_and_compute_mask(inputs, **kwargs)
+ masks = generic_utils.to_list(masks)
else:
- if model._expects_training_arg:
- outs = model.call(inputs, training=training)
- else:
- outs = model.call(inputs)
- if not isinstance(outs, list):
- outs = [outs]
+ outs = model.call(inputs, **kwargs)
+ masks = None
- if not isinstance(targets, list):
- targets = [targets]
+ outs = generic_utils.to_list(outs)
+ if masks is None:
+ masks = [None for _ in outs]
+ targets = generic_utils.to_list(targets)
loss_metrics = []
with backend.name_scope('loss'):
@@ -115,10 +117,7 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
weights = sample_weights[i]
else:
weights = None
-
- # TODO(fchollet): support masking; in practice `_keras_mask` is never
- # set in this context currently.
- mask = outs[i]._keras_mask
+ mask = masks[i]
weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
with backend.name_scope(model.output_names[i] + '_loss'):
@@ -220,10 +219,11 @@ def iterator_fit_loop(model,
next_element = inputs.get_next()
except errors.OutOfRangeError:
logging.warning(
- 'Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset'
- ' can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' % steps_per_epoch * epochs)
+ 'Your dataset iterator ran out of data; interrupting training. Make '
+ 'sure that your dataset can generate at least '
+ '`steps_per_epoch * epochs` batches (in this case, %d batches). You '
+ 'may need to use the repeat() function when building your '
+ 'dataset.' % (steps_per_epoch * epochs))
break
if len(inputs.output_shapes) == 2:
@@ -335,7 +335,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
logging.warning(
'Your dataset iterator ran out of data interrupting testing. '
'Make sure that your dataset can generate at least `steps` batches '
- '(in this case, %d batches).', steps)
+ '(in this case, %d batches). You may need to use the repeat() '
+ 'function when building your dataset.', steps)
break
if len(inputs.output_shapes) == 2:
@@ -426,10 +427,10 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
next_element = inputs.get_next()
except errors.OutOfRangeError:
logging.warning(
- 'Your dataset iterator ran out of data; '
- 'interrupting prediction. Make sure that your '
- 'dataset can generate at least `steps` '
- 'batches (in this case, %d batches).', steps)
+ 'Your dataset iterator ran out of data; interrupting prediction. '
+ 'Make sure that your dataset can generate at least `steps` batches '
+ '(in this case, %d batches). You may need to use the repeat() '
+ 'function when building your dataset.', steps)
break
# expects a tuple, where first element of tuple represents inputs
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index 605c1935a5..56f321732f 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -167,27 +167,6 @@ class CorrectnessTest(test.TestCase):
np.around(history.history['loss'][-1], decimals=4), 0.6173)
@tf_test_util.run_in_graph_and_eager_modes
- def test_metrics_correctness(self):
- model = keras.Sequential()
- model.add(keras.layers.Dense(3,
- activation='relu',
- input_dim=4,
- kernel_initializer='ones'))
- model.add(keras.layers.Dense(1,
- activation='sigmoid',
- kernel_initializer='ones'))
- model.compile(loss='mae',
- metrics=['acc'],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
- x = np.ones((100, 4))
- y = np.ones((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 1.)
- y = np.zeros((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 0.)
-
- @tf_test_util.run_in_graph_and_eager_modes
def test_loss_correctness_with_iterator(self):
# Test that training loss is the same in eager and graph
# (by comparing it to a reference value in a deterministic case)
@@ -210,35 +189,6 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
- @tf_test_util.run_in_graph_and_eager_modes
- def test_metrics_correctness_with_iterator(self):
- model = keras.Sequential()
- model.add(
- keras.layers.Dense(
- 8, activation='relu', input_dim=4, kernel_initializer='ones'))
- model.add(
- keras.layers.Dense(1, activation='sigmoid', kernel_initializer='ones'))
- model.compile(
- loss='binary_crossentropy',
- metrics=['accuracy'],
- optimizer=RMSPropOptimizer(learning_rate=0.001))
- np.random.seed(123)
- x = np.random.randint(10, size=(100, 4)).astype(np.float32)
- y = np.random.randint(2, size=(100, 1)).astype(np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
- outs = model.evaluate(iterator, steps=10)
- self.assertEqual(np.around(outs[1], decimals=1), 0.5)
-
- y = np.zeros((100, 1), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
- iterator = dataset.make_one_shot_iterator()
- outs = model.evaluate(iterator, steps=10)
- self.assertEqual(outs[1], 0.)
-
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index e89847dbbf..753519fbac 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -793,19 +793,48 @@ class LossWeightingTest(test.TestCase):
class LossMaskingTest(test.TestCase):
+ @tf_test_util.run_in_graph_and_eager_modes
def test_masking(self):
with self.test_session():
- np.random.seed(1337)
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
model.add(
keras.layers.TimeDistributed(
keras.layers.Dense(1, kernel_initializer='one')))
- model.compile(loss='mse', optimizer='sgd')
+ model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
y = np.array([[[1], [1]], [[1], [1]]])
loss = model.train_on_batch(x, y)
- self.assertEqual(loss, 0)
+ self.assertEqual(float(loss), 0.)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_mask_argument_in_layer(self):
+ # Test that the mask argument gets correctly passed to a layer in the
+ # functional API.
+
+ class CustomMaskedLayer(keras.layers.Layer):
+
+ def __init__(self):
+ super(CustomMaskedLayer, self).__init__()
+ self.supports_masking = True
+
+ def call(self, inputs, mask=None):
+ assert mask is not None
+ return inputs
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+ with self.test_session():
+ x = np.random.random((5, 3))
+ inputs = keras.layers.Input((3,))
+ masked = keras.layers.Masking(mask_value=0)(inputs)
+ outputs = CustomMaskedLayer()(masked)
+
+ model = keras.Model(inputs, outputs)
+ model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+ y = np.random.random((5, 3))
+ model.train_on_batch(x, y)
def test_loss_masking(self):
with self.test_session():
@@ -2061,5 +2090,91 @@ class TestTrainingWithDataset(test.TestCase):
model.train_on_batch(dataset)
+class TestTrainingWithMetrics(test.TestCase):
+ """Training tests related to metrics."""
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_metrics_correctness(self):
+ with self.test_session():
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 3, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(
+ 1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='mae',
+ metrics=['accuracy'],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ # verify correctness of stateful and stateless metrics.
+ x = np.ones((100, 4))
+ y = np.ones((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 1.)
+
+ y = np.zeros((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 0.)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_metrics_correctness_with_iterator(self):
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(
+ 8, activation='relu', input_dim=4, kernel_initializer='ones'))
+ model.add(
+ keras.layers.Dense(1, activation='sigmoid', kernel_initializer='ones'))
+ model.compile(
+ loss='binary_crossentropy',
+ metrics=['accuracy'],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+
+ np.random.seed(123)
+ x = np.random.randint(10, size=(100, 4)).astype(np.float32)
+ y = np.random.randint(2, size=(100, 1)).astype(np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(np.around(outs[1], decimals=1), 0.5)
+
+ y = np.zeros((100, 1), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+ iterator = dataset.make_one_shot_iterator()
+ outs = model.evaluate(iterator, steps=10)
+ self.assertEqual(outs[1], 0.)
+
+ def test_metrics_correctness_with_weighted_metrics(self):
+ with self.test_session():
+ np.random.seed(1337)
+ x = np.array([[[1.], [1.]], [[0.], [0.]]])
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(1, kernel_initializer='ones'),
+ input_shape=(2, 1)))
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ loss='mse',
+ sample_weight_mode='temporal',
+ weighted_metrics=['accuracy'])
+ y = np.array([[[1.], [1.]], [[1.], [1.]]])
+
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs, [0.5, 0.5])
+
+ w = np.array([[0., 0.], [0., 0.]])
+ outs = model.evaluate(x, y, sample_weight=w)
+ self.assertEqual(outs, [0., 0.])
+
+ w = np.array([[3., 4.], [1., 2.]])
+ outs = model.evaluate(x, y, sample_weight=w)
+ self.assertArrayNear(outs, [0.3, 0.7], .001)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index b304cb9093..38b64e69ec 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -33,6 +33,7 @@ from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import weights_broadcast_ops
def _map_nested(data, func):
@@ -577,15 +578,25 @@ def weighted_masked_objective(fn):
# to the number of unmasked samples.
score_array /= K.mean(mask)
- # apply sample weighting
+ # Apply sample weighting.
if weights is not None:
- # reduce score_array to same ndim as weight array
- ndim = K.ndim(score_array)
- weight_ndim = K.ndim(weights)
- score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
- score_array *= weights
- score_array /= K.mean(
- math_ops.cast(math_ops.not_equal(weights, 0), K.floatx()))
+
+ # Update dimensions of weights to match with values if possible.
+ score_array, _, weights = metrics_module.squeeze_or_expand_dimensions(
+ score_array, None, weights)
+ try:
+ # Broadcast weights if possible.
+ weights = weights_broadcast_ops.broadcast_weights(weights, score_array)
+ except ValueError:
+ # Reduce values to same ndim as weight array.
+ ndim = K.ndim(score_array)
+ weight_ndim = K.ndim(weights)
+ score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
+
+ score_array = math_ops.multiply(score_array, weights)
+ score_array = math_ops.reduce_sum(score_array)
+ weights = math_ops.reduce_sum(weights)
+ score_array = metrics_module.safe_div(score_array, weights)
return K.mean(score_array)
return weighted
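The reworked weighting above first tries to broadcast the weights against the score tensor and, when that fails, reduces the score to the weights' rank before taking a weighted mean guarded by `safe_div`. A NumPy toy version of that logic, for intuition only (the real code operates on tensors via `weights_broadcast_ops` and `metrics_module.safe_div`):

import numpy as np

def toy_weighted_score(score, weights):
    score = np.asarray(score, dtype=np.float32)
    weights = np.asarray(weights, dtype=np.float32)
    if weights.shape != score.shape:
        # Fallback path: average out the trailing axes so the score matches
        # the weights' rank, mirroring the ValueError branch above.
        score = score.mean(axis=tuple(range(weights.ndim, score.ndim)))
    total = float((score * weights).sum())
    denom = float(weights.sum())
    # safe_div behaviour: return 0 instead of dividing by a zero weight sum.
    return total / denom if denom > 0 else 0.0

print(toy_weighted_score([[1., 0.], [1., 1.]], [2., 1.]))  # ~0.667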
diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py
index 57f660b6d5..afef997b00 100644
--- a/tensorflow/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/layers/gru_test.py
@@ -183,6 +183,7 @@ class GRULayerTest(test.TestCase):
self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
self.assertEqual(layer.cell.bias.constraint, b_constraint)
+ @tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_GRU(self):
layer_class = keras.layers.GRU
with self.test_session():
@@ -192,7 +193,8 @@ class GRULayerTest(test.TestCase):
model = keras.models.Sequential()
model.add(keras.layers.Masking(input_shape=(3, 4)))
model.add(layer_class(units=5, return_sequences=True, unroll=False))
- model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(0.01))
model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)
def test_from_config_GRU(self):
diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py
index ae381f5955..9802820fd0 100644
--- a/tensorflow/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/layers/lstm_test.py
@@ -197,6 +197,7 @@ class LSTMLayerTest(test.TestCase):
self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
self.assertEqual(layer.cell.bias.constraint, b_constraint)
+ @tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_LSTM(self):
layer_class = keras.layers.LSTM
with self.test_session():
@@ -206,7 +207,8 @@ class LSTMLayerTest(test.TestCase):
model = keras.models.Sequential()
model.add(keras.layers.Masking(input_shape=(3, 4)))
model.add(layer_class(units=5, return_sequences=True, unroll=False))
- model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(0.01))
model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)
def test_from_config_LSTM(self):
@@ -311,7 +313,8 @@ class LSTMLayerTest(test.TestCase):
output = keras.layers.LSTM(units)(inputs, initial_state=initial_state)
model = keras.models.Model([inputs] + initial_state, output)
- model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(0.01))
inputs = np.random.random((num_samples, timesteps, embedding_dim))
initial_state = [np.random.random((num_samples, units))
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 534c0eca08..a8bfdf25f2 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -23,7 +23,6 @@ import numbers
import numpy as np
from tensorflow.python.eager import context
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
@@ -2231,342 +2230,6 @@ def _generate_dropout_mask(ones, rate, training=None, count=1):
return K.in_train_phase(dropped_inputs, ones, training=training)
-class Recurrent(Layer):
- """Deprecated abstract base class for recurrent layers.
-
- It still exists because it is leveraged by the convolutional-recurrent layers.
- It will be removed entirely in the future.
- It was never part of the public API.
- Do not use.
-
- Arguments:
- weights: list of Numpy arrays to set as initial weights.
- The list should have 3 elements, of shapes:
- `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
- return_sequences: Boolean. Whether to return the last output
- in the output sequence, or the full sequence.
- return_state: Boolean. Whether to return the last state
- in addition to the output.
- go_backwards: Boolean (default False).
- If True, process the input sequence backwards and return the
- reversed sequence.
- stateful: Boolean (default False). If True, the last state
- for each sample at index i in a batch will be used as initial
- state for the sample of index i in the following batch.
- unroll: Boolean (default False).
- If True, the network will be unrolled,
- else a symbolic loop will be used.
- Unrolling can speed-up a RNN,
- although it tends to be more memory-intensive.
- Unrolling is only suitable for short sequences.
- implementation: one of {0, 1, or 2}.
- If set to 0, the RNN will use
- an implementation that uses fewer, larger matrix products,
- thus running faster on CPU but consuming more memory.
- If set to 1, the RNN will use more matrix products,
- but smaller ones, thus running slower
- (may actually be faster on GPU) while consuming less memory.
- If set to 2 (LSTM/GRU only),
- the RNN will combine the input gate,
- the forget gate and the output gate into a single matrix,
- enabling more time-efficient parallelization on the GPU.
- Note: RNN dropout must be shared for all gates,
- resulting in a slightly reduced regularization.
- input_dim: dimensionality of the input (integer).
- This argument (or alternatively, the keyword argument `input_shape`)
- is required when using this layer as the first layer in a model.
- input_length: Length of input sequences, to be specified
- when it is constant.
- This argument is required if you are going to connect
- `Flatten` then `Dense` layers upstream
- (without it, the shape of the dense outputs cannot be computed).
- Note that if the recurrent layer is not the first layer
- in your model, you would need to specify the input length
- at the level of the first layer
- (e.g. via the `input_shape` argument)
-
- Input shape:
- 3D tensor with shape `(batch_size, timesteps, input_dim)`,
- (Optional) 2D tensors with shape `(batch_size, output_dim)`.
-
- Output shape:
- - if `return_state`: a list of tensors. The first tensor is
- the output. The remaining tensors are the last states,
- each with shape `(batch_size, units)`.
- - if `return_sequences`: 3D tensor with shape
- `(batch_size, timesteps, units)`.
- - else, 2D tensor with shape `(batch_size, units)`.
-
- # Masking
- This layer supports masking for input data with a variable number
- of timesteps. To introduce masks to your data,
- use an `Embedding` layer with the `mask_zero` parameter
- set to `True`.
-
- # Note on using statefulness in RNNs
- You can set RNN layers to be 'stateful', which means that the states
- computed for the samples in one batch will be reused as initial states
- for the samples in the next batch. This assumes a one-to-one mapping
- between samples in different successive batches.
-
- To enable statefulness:
- - specify `stateful=True` in the layer constructor.
- - specify a fixed batch size for your model, by passing
- if sequential model:
- `batch_input_shape=(...)` to the first layer in your model.
- else for functional model with 1 or more Input layers:
- `batch_shape=(...)` to all the first layers in your model.
- This is the expected shape of your inputs
- *including the batch size*.
- It should be a tuple of integers, e.g. `(32, 10, 100)`.
- - specify `shuffle=False` when calling fit().
-
- To reset the states of your model, call `.reset_states()` on either
- a specific layer, or on your entire model.
-
- # Note on specifying the initial state of RNNs
- You can specify the initial state of RNN layers symbolically by
- calling them with the keyword argument `initial_state`. The value of
- `initial_state` should be a tensor or list of tensors representing
- the initial state of the RNN layer.
-
- You can specify the initial state of RNN layers numerically by
- calling `reset_states` with the keyword argument `states`. The value of
- `states` should be a numpy array or list of numpy arrays representing
- the initial state of the RNN layer.
- """
-
- def __init__(self,
- return_sequences=False,
- return_state=False,
- go_backwards=False,
- stateful=False,
- unroll=False,
- implementation=0,
- **kwargs):
- super(Recurrent, self).__init__(**kwargs)
- self.return_sequences = return_sequences
- self.return_state = return_state
- self.go_backwards = go_backwards
- self.stateful = stateful
- self.unroll = unroll
- self.implementation = implementation
- self.supports_masking = True
- self.input_spec = [InputSpec(ndim=3)]
- self.state_spec = None
- self.dropout = 0
- self.recurrent_dropout = 0
-
- @tf_utils.shape_type_conversion
- def compute_output_shape(self, input_shape):
- if isinstance(input_shape, list):
- input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- if self.return_sequences:
- output_shape = (input_shape[0], input_shape[1], self.units)
- else:
- output_shape = (input_shape[0], self.units)
-
- if self.return_state:
- state_shape = [tensor_shape.TensorShape(
- (input_shape[0], self.units)) for _ in self.states]
- return [tensor_shape.TensorShape(output_shape)] + state_shape
- return tensor_shape.TensorShape(output_shape)
-
- def compute_mask(self, inputs, mask):
- if isinstance(mask, list):
- mask = mask[0]
- output_mask = mask if self.return_sequences else None
- if self.return_state:
- state_mask = [None for _ in self.states]
- return [output_mask] + state_mask
- return output_mask
-
- def step(self, inputs, states):
- raise NotImplementedError
-
- def get_constants(self, inputs, training=None):
- return []
-
- def get_initial_state(self, inputs):
- # build an all-zero tensor of shape (samples, output_dim)
- initial_state = array_ops.zeros_like(inputs)
- # shape of initial_state = (samples, timesteps, input_dim)
- initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2))
- # shape of initial_state = (samples,)
- initial_state = array_ops.expand_dims(initial_state, axis=-1)
- # shape of initial_state = (samples, 1)
- initial_state = K.tile(initial_state, [1,
- self.units]) # (samples, output_dim)
- initial_state = [initial_state for _ in range(len(self.states))]
- return initial_state
-
- def preprocess_input(self, inputs, training=None):
- return inputs
-
- def __call__(self, inputs, initial_state=None, **kwargs):
- if (isinstance(inputs, (list, tuple)) and
- len(inputs) > 1
- and initial_state is None):
- initial_state = inputs[1:]
- inputs = inputs[0]
-
- # If `initial_state` is specified,
- # and if it a Keras tensor,
- # then add it to the inputs and temporarily
- # modify the input spec to include the state.
- if initial_state is None:
- return super(Recurrent, self).__call__(inputs, **kwargs)
-
- if not isinstance(initial_state, (list, tuple)):
- initial_state = [initial_state]
-
- is_keras_tensor = hasattr(initial_state[0], '_keras_history')
- for tensor in initial_state:
- if hasattr(tensor, '_keras_history') != is_keras_tensor:
- raise ValueError('The initial state of an RNN layer cannot be'
- ' specified with a mix of Keras tensors and'
- ' non-Keras tensors')
-
- if is_keras_tensor:
- # Compute the full input spec, including state
- input_spec = self.input_spec
- state_spec = self.state_spec
- if not isinstance(input_spec, list):
- input_spec = [input_spec]
- if not isinstance(state_spec, list):
- state_spec = [state_spec]
- self.input_spec = input_spec + state_spec
-
- # Compute the full inputs, including state
- inputs = [inputs] + list(initial_state)
-
- # Perform the call
- output = super(Recurrent, self).__call__(inputs, **kwargs)
-
- # Restore original input spec
- self.input_spec = input_spec
- return output
- else:
- kwargs['initial_state'] = initial_state
- return super(Recurrent, self).__call__(inputs, **kwargs)
-
- def call(self, inputs, mask=None, training=None, initial_state=None):
- # input shape: `(samples, time (padded with zeros), input_dim)`
- # note that the .build() method of subclasses MUST define
- # self.input_spec and self.state_spec with complete input shapes.
- if isinstance(inputs, list):
- initial_state = inputs[1:]
- inputs = inputs[0]
- elif initial_state is not None:
- pass
- elif self.stateful:
- initial_state = self.states
- else:
- initial_state = self.get_initial_state(inputs)
-
- if isinstance(mask, list):
- mask = mask[0]
-
- if len(initial_state) != len(self.states):
- raise ValueError('Layer has ' + str(len(self.states)) +
- ' states but was passed ' + str(len(initial_state)) +
- ' initial states.')
- input_shape = K.int_shape(inputs)
- if self.unroll and input_shape[1] is None:
- raise ValueError('Cannot unroll a RNN if the '
- 'time dimension is undefined. \n'
- '- If using a Sequential model, '
- 'specify the time dimension by passing '
- 'an `input_shape` or `batch_input_shape` '
- 'argument to your first layer. If your '
- 'first layer is an Embedding, you can '
- 'also use the `input_length` argument.\n'
- '- If using the functional API, specify '
- 'the time dimension by passing a `shape` '
- 'or `batch_shape` argument to your Input layer.')
- constants = self.get_constants(inputs, training=None)
- preprocessed_input = self.preprocess_input(inputs, training=None)
- last_output, outputs, states = K.rnn(
- self.step,
- preprocessed_input,
- initial_state,
- go_backwards=self.go_backwards,
- mask=mask,
- constants=constants,
- unroll=self.unroll)
- if self.stateful:
- updates = []
- for i in range(len(states)):
- updates.append(state_ops.assign(self.states[i], states[i]))
- self.add_update(updates, inputs)
-
- # Properly set learning phase
- if 0 < self.dropout + self.recurrent_dropout:
- last_output._uses_learning_phase = True
- outputs._uses_learning_phase = True
-
- if not self.return_sequences:
- outputs = last_output
-
- if self.return_state:
- if not isinstance(states, (list, tuple)):
- states = [states]
- else:
- states = list(states)
- return [outputs] + states
- return outputs
-
- def reset_states(self, states=None):
- if not self.stateful:
- raise AttributeError('Layer must be stateful.')
- batch_size = self.input_spec[0].shape[0]
- if not batch_size:
- raise ValueError('If a RNN is stateful, it needs to know '
- 'its batch size. Specify the batch size '
- 'of your input tensors: \n'
- '- If using a Sequential model, '
- 'specify the batch size by passing '
- 'a `batch_input_shape` '
- 'argument to your first layer.\n'
- '- If using the functional API, specify '
- 'the time dimension by passing a '
- '`batch_shape` argument to your Input layer.')
- # initialize state if None
- if self.states[0] is None:
- self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
- elif states is None:
- for state in self.states:
- K.set_value(state, np.zeros((batch_size, self.units)))
- else:
- if not isinstance(states, (list, tuple)):
- states = [states]
- if len(states) != len(self.states):
- raise ValueError('Layer ' + self.name + ' expects ' +
- str(len(self.states)) + ' states, '
- 'but it received ' + str(len(states)) +
- ' state values. Input received: ' + str(states))
- for index, (value, state) in enumerate(zip(states, self.states)):
- if value.shape != (batch_size, self.units):
- raise ValueError('State ' + str(index) +
- ' is incompatible with layer ' + self.name +
- ': expected shape=' + str((batch_size, self.units)) +
- ', found shape=' + str(value.shape))
- K.set_value(state, value)
-
- def get_config(self):
- config = {
- 'return_sequences': self.return_sequences,
- 'return_state': self.return_state,
- 'go_backwards': self.go_backwards,
- 'stateful': self.stateful,
- 'unroll': self.unroll,
- 'implementation': self.implementation
- }
- base_config = super(Recurrent, self).get_config()
- return dict(list(base_config.items()) + list(config.items()))
-
-
def _standardize_args(inputs, initial_state, constants, num_constants):
"""Standardizes `__call__` to a single list of tensor inputs.
diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py
index 18fefbe84f..1429537648 100644
--- a/tensorflow/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/layers/simplernn_test.py
@@ -183,6 +183,7 @@ class SimpleRNNLayerTest(test.TestCase):
self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
self.assertEqual(layer.cell.bias.constraint, b_constraint)
+ @tf_test_util.run_in_graph_and_eager_modes
def test_with_masking_layer_SimpleRNN(self):
layer_class = keras.layers.SimpleRNN
with self.test_session():
@@ -192,7 +193,8 @@ class SimpleRNNLayerTest(test.TestCase):
model = keras.models.Sequential()
model.add(keras.layers.Masking(input_shape=(3, 4)))
model.add(layer_class(units=5, return_sequences=True, unroll=False))
- model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.compile(loss='categorical_crossentropy',
+ optimizer=RMSPropOptimizer(0.01))
model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1)
def test_from_config_SimpleRNN(self):
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 7d8b1fec45..b18f12612a 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -141,7 +141,7 @@ def result_wrapper(result_fn):
return tf_decorator.make_decorator(result_fn, decorated)
-def _safe_div(numerator, denominator):
+def safe_div(numerator, denominator):
"""Divides two tensors element-wise, returning 0 if the denominator is <= 0.
Args:
@@ -158,7 +158,7 @@ def _safe_div(numerator, denominator):
return array_ops.where(condition, t, zero)
-def _squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
+def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
"""Squeeze or expand last dimension if needed.
1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
@@ -275,7 +275,7 @@ class Metric(Layer):
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = math_ops.cast(y_true, dtypes.bool)
y_pred = math_ops.cast(y_pred, dtypes.bool)
- y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions(
+ y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
y_pred, y_true, sample_weight)
values = math_ops.logical_and(
@@ -420,11 +420,20 @@ class Mean(Metric):
else:
sample_weight = math_ops.cast(sample_weight, self._dtype)
- # Update dimensions of weights to match with values.
- values, _, sample_weight = _squeeze_or_expand_dimensions(
+ # Update dimensions of weights to match with values if possible.
+ values, _, sample_weight = squeeze_or_expand_dimensions(
values, None, sample_weight)
- sample_weight = weights_broadcast_ops.broadcast_weights(
- sample_weight, values)
+ try:
+ # Broadcast weights if possible.
+ sample_weight = weights_broadcast_ops.broadcast_weights(
+ sample_weight, values)
+ except ValueError:
+ # Reduce values to same ndim as weight array
+ ndim = K.ndim(values)
+ weight_ndim = K.ndim(sample_weight)
+ values = math_ops.reduce_mean(
+ values, axis=list(range(weight_ndim, ndim)))
+
num_values = math_ops.reduce_sum(sample_weight)
values = math_ops.multiply(values, sample_weight)
values = math_ops.reduce_sum(values)
@@ -434,7 +443,7 @@ class Mean(Metric):
state_ops.assign_add(self.count, num_values)
def result(self):
- return _safe_div(self.total, self.count)
+ return safe_div(self.total, self.count)
class MeanMetricWrapper(Mean):
@@ -468,7 +477,7 @@ class MeanMetricWrapper(Mean):
"""
y_true = math_ops.cast(y_true, self._dtype)
y_pred = math_ops.cast(y_pred, self._dtype)
- y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions(
+ y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
y_pred, y_true, sample_weight)
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
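`_safe_div` and `_squeeze_or_expand_dimensions` lose their leading underscores above so that `training_utils` can reuse them. For reference, a rough NumPy analogue of the `safe_div` contract (zero wherever the denominator is not positive, instead of inf/nan):

import numpy as np

def safe_div_np(numerator, denominator):
    numerator = np.asarray(numerator, dtype=np.float32)
    denominator = np.asarray(denominator, dtype=np.float32)
    out = np.zeros_like(numerator)
    positive = denominator > 0
    out[positive] = numerator[positive] / denominator[positive]
    return out

print(safe_div_np([1., 2., 3.], [2., 0., -1.]))  # [0.5 0.  0. ]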
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index d583379708..49f3ae40d9 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -258,6 +258,13 @@ class KerasMetricsTest(test.TestCase):
self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1
self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2
+ # check values reduced to the dimensions of weight
+ result_t = m([[[1., 2.], [3., 2.], [0.5, 4.]]], sample_weight=[0.5])
+ result = np.round(self.evaluate(result_t), decimals=2) # 58.5 / 5.6
+ self.assertEqual(result, 10.45)
+ self.assertEqual(np.round(self.evaluate(m.total), decimals=2), 58.54)
+ self.assertEqual(np.round(self.evaluate(m.count), decimals=2), 5.6)
+
def test_mean_graph_with_placeholder(self):
with context.graph_mode(), self.test_session() as sess:
m = metrics.Mean()
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 3a153573f8..6cbea45bd5 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -189,6 +189,27 @@ def get_nested_model_3(input_dim, num_classes):
class ModelSubclassingTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
+ def test_custom_build(self):
+ class DummyModel(keras.Model):
+
+ def __init__(self):
+ super(DummyModel, self).__init__()
+ self.dense1 = keras.layers.Dense(32, activation='relu')
+ self.uses_custom_build = False
+
+ def call(self, inputs):
+ return self.dense1(inputs)
+
+ def build(self, input_shape):
+ self.uses_custom_build = True
+
+ test_model = DummyModel()
+ dummy_data = array_ops.ones((32, 50))
+ test_model(dummy_data)
+ self.assertTrue(test_model.uses_custom_build, 'Model should use user '
+ 'defined build when called.')
+
+ @test_util.run_in_graph_and_eager_modes
def test_invalid_input_shape_build(self):
num_classes = 2
input_dim = 50
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 21217fdca1..0bd6620220 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -26,7 +26,6 @@ from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.utils import generic_utils
-from tensorflow.python.keras.utils.generic_utils import has_arg
# API entries importable from `keras.models`:
@@ -69,7 +68,7 @@ def _clone_functional_model(model, input_tensors=None):
'got a `Sequential` instance instead:', model)
layer_map = {} # Cache for created layers.
- tensor_map = {} # Map {reference_tensor: (corresponding_tensor, mask)}
+ tensor_map = {} # Map {reference_tensor: corresponding_tensor}
if input_tensors is None:
# Create placeholders to build the model on top of.
input_layers = []
@@ -106,7 +105,7 @@ def _clone_functional_model(model, input_tensors=None):
input_tensors = input_tensors_
for x, y in zip(model.inputs, input_tensors):
- tensor_map[x] = (y, None) # tensor, mask
+ tensor_map[x] = y
# Iterated over every node in the reference model, in depth order.
depth_keys = list(model._nodes_by_depth.keys())
@@ -131,55 +130,41 @@ def _clone_functional_model(model, input_tensors=None):
continue
# Gather inputs to call the new layer.
- referenceinput_tensors_ = node.input_tensors
+ reference_input_tensors = node.input_tensors
reference_output_tensors = node.output_tensors
# If all previous input tensors are available in tensor_map,
# then call node.inbound_layer on them.
- computed_data = [] # List of tuples (input, mask).
- for x in referenceinput_tensors_:
+ computed_tensors = []
+ for x in reference_input_tensors:
if x in tensor_map:
- computed_data.append(tensor_map[x])
+ computed_tensors.append(tensor_map[x])
- if len(computed_data) == len(referenceinput_tensors_):
+ if len(computed_tensors) == len(reference_input_tensors):
# Call layer.
if node.arguments:
kwargs = node.arguments
else:
kwargs = {}
- if len(computed_data) == 1:
- computed_tensor, computed_mask = computed_data[0]
- if has_arg(layer.call, 'mask'):
- if 'mask' not in kwargs:
- kwargs['mask'] = computed_mask
+ if len(computed_tensors) == 1:
+ computed_tensor = computed_tensors[0]
output_tensors = generic_utils.to_list(layer(computed_tensor,
**kwargs))
- output_masks = generic_utils.to_list(
- layer.compute_mask(computed_tensor, computed_mask))
computed_tensors = [computed_tensor]
- computed_masks = [computed_mask]
else:
- computed_tensors = [x[0] for x in computed_data]
- computed_masks = [x[1] for x in computed_data]
- if has_arg(layer.call, 'mask'):
- if 'mask' not in kwargs:
- kwargs['mask'] = computed_masks
+ computed_tensors = computed_tensors
output_tensors = generic_utils.to_list(layer(computed_tensors,
**kwargs))
- output_masks = generic_utils.to_list(
- layer.compute_mask(computed_tensors, computed_masks))
- # Update tensor_map.
- for x, y, mask in zip(reference_output_tensors, output_tensors,
- output_masks):
- tensor_map[x] = (y, mask)
+
+ for x, y in zip(reference_output_tensors, output_tensors):
+ tensor_map[x] = y
# Check that we did compute the model outputs,
# then instantiate a new model from inputs and outputs.
output_tensors = []
for x in model.outputs:
assert x in tensor_map, 'Could not compute output ' + str(x)
- tensor, _ = tensor_map[x]
- output_tensors.append(tensor)
+ output_tensors.append(tensor_map[x])
return Model(input_tensors, output_tensors, name=model.name)
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 1525104ac9..1385ad5390 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -115,6 +115,22 @@ class TestModelCloning(test.TestCase):
new_model.compile('rmsprop', 'mse')
new_model.train_on_batch(None, val_out)
+ @test_util.run_in_graph_and_eager_modes
+ def test_clone_functional_model_with_masking(self):
+ with self.test_session():
+ x = np.array([[[1], [1]], [[0], [0]]])
+ inputs = keras.Input((2, 1))
+ outputs = keras.layers.Masking(mask_value=0)(inputs)
+ outputs = keras.layers.TimeDistributed(
+ keras.layers.Dense(1, kernel_initializer='one'))(outputs)
+ model = keras.Model(inputs, outputs)
+
+ model = keras.models.clone_model(model)
+ model.compile(loss='mse', optimizer=adam.AdamOptimizer(0.01))
+ y = np.array([[[1], [1]], [[1], [1]]])
+ loss = model.train_on_batch(x, y)
+ self.assertEqual(float(loss), 0.)
+
def test_model_cloning_invalid_use_cases(self):
seq_model = keras.models.Sequential()
seq_model.add(keras.layers.Dense(4, input_shape=(4,)))
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index a69893955f..2e56fa2dc5 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -162,7 +162,7 @@ def deserialize_keras_object(identifier,
if cls is None:
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
if hasattr(cls, 'from_config'):
- arg_spec = tf_inspect.getargspec(cls.from_config)
+ arg_spec = tf_inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}
if 'custom_objects' in arg_spec.args:
@@ -281,8 +281,8 @@ def has_arg(fn, name, accept_all=False):
Returns:
bool, whether `fn` accepts a `name` keyword argument.
"""
- arg_spec = tf_inspect.getargspec(fn)
- if accept_all and arg_spec.keywords is not None:
+ arg_spec = tf_inspect.getfullargspec(fn)
+ if accept_all and arg_spec.varkw is not None:
return True
return name in arg_spec.args
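
# A minimal Python 3 sketch (illustrative, not part of the change above) of why
# the migration from getargspec to getfullargspec matters: getargspec() raises
# for functions with keyword-only parameters, and getfullargspec() exposes a
# **kwargs name as `varkw` rather than `keywords`. The function below is
# hypothetical.
import inspect

def _f(a, b=1, *args, c, **kwargs):
  return a, b, args, c, kwargs

_spec = inspect.getfullargspec(_f)
assert _spec.args == ['a', 'b']       # plain positional-or-keyword arguments
assert _spec.varkw == 'kwargs'        # was `keywords` on the old getargspec
assert _spec.kwonlyargs == ['c']      # getargspec() would raise ValueError here
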
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index edb0910354..2451dc7257 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -566,6 +566,7 @@ tf_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:linalg_ops",
],
+ shard_count = 16,
)
tf_py_test(
@@ -701,7 +702,7 @@ tf_py_test(
tf_py_test(
name = "priority_queue_test",
- size = "small",
+ size = "medium",
srcs = ["priority_queue_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 6f401358a2..0e4e58409e 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import test as test_lib
@@ -173,6 +174,10 @@ if __name__ == '__main__':
_AddTest(MatrixUnaryFunctorGradientTest, 'MatrixInverseGradient', name,
_GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_inverse,
dtype, shape))
+ _AddTest(MatrixUnaryFunctorGradientTest, 'MatrixExponentialGradient',
+ name,
+ _GetMatrixUnaryFunctorGradientTest(
+ linalg_impl.matrix_exponential, dtype, shape))
_AddTest(
MatrixUnaryFunctorGradientTest, 'MatrixDeterminantGradient', name,
_GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_determinant,
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
index a0c66c77d8..0386e91276 100644
--- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
@@ -12,33 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tensorflow.ops.gen_linalg_ops.matrix_exponential."""
+"""Tests for tensorflow.ops.linalg.linalg_impl.matrix_exponential."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
-import math
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import test
-def np_expm(x):
+def np_expm(x): # pylint: disable=invalid-name
"""Slow but accurate Taylor series matrix exponential."""
y = np.zeros(x.shape, dtype=x.dtype)
xn = np.eye(x.shape[0], dtype=x.dtype)
for n in range(40):
- y += xn / float(math.factorial(n))
+ if n > 0:
+ xn /= float(n)
+ y += xn
xn = np.dot(xn, x)
return y
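
# A small numpy-only sketch (illustrative, not part of the change above): the
# rewritten loop builds the Taylor term x**n / n! incrementally (xn /= n each
# step) instead of calling math.factorial(n), which keeps each term well scaled
# and drops the math import. The same recurrence, checked on a matrix whose
# exponential is known in closed form:
import numpy as np

_x = np.array([[0., 1.], [-1., 0.]])   # expm(_x) is a rotation by 1 radian
_y = np.zeros_like(_x)
_xn = np.eye(2)
for _n in range(40):
  if _n > 0:
    _xn /= float(_n)                   # _xn now holds _x**_n / _n!
  _y += _xn                            # accumulate the Taylor term
  _xn = np.dot(_xn, _x)                # next power, still scaled by 1/_n!
_expected = np.array([[np.cos(1.), np.sin(1.)], [-np.sin(1.), np.cos(1.)]])
assert np.allclose(_y, _expected)
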
@@ -48,7 +50,7 @@ class ExponentialOpTest(test.TestCase):
def _verifyExponential(self, x, np_type):
inp = x.astype(np_type)
with self.test_session(use_gpu=True):
- tf_ans = gen_linalg_ops.matrix_exponential(inp)
+ tf_ans = linalg_impl.matrix_exponential(inp)
if x.size == 0:
np_ans = np.empty(x.shape, dtype=np_type)
else:
@@ -76,7 +78,7 @@ class ExponentialOpTest(test.TestCase):
matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
return matrix_batch
- def testNonsymmetric(self):
+ def testNonsymmetricReal(self):
# 2x2 matrices
matrix1 = np.array([[1., 2.], [3., 4.]])
matrix2 = np.array([[1., 3.], [3., 5.]])
@@ -84,7 +86,10 @@ class ExponentialOpTest(test.TestCase):
self._verifyExponentialReal(matrix2)
# A multidimensional batch of 2x2 matrices
self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
- # Complex
+
+ def testNonsymmetricComplex(self):
+ matrix1 = np.array([[1., 2.], [3., 4.]])
+ matrix2 = np.array([[1., 3.], [3., 5.]])
matrix1 = matrix1.astype(np.complex64)
matrix1 += 1j * matrix1
matrix2 = matrix2.astype(np.complex64)
@@ -94,7 +99,7 @@ class ExponentialOpTest(test.TestCase):
# Complex batch
self._verifyExponentialComplex(self._makeBatch(matrix1, matrix2))
- def testSymmetricPositiveDefinite(self):
+ def testSymmetricPositiveDefiniteReal(self):
# 2x2 matrices
matrix1 = np.array([[2., 1.], [1., 2.]])
matrix2 = np.array([[3., -1.], [-1., 3.]])
@@ -102,7 +107,10 @@ class ExponentialOpTest(test.TestCase):
self._verifyExponentialReal(matrix2)
# A multidimensional batch of 2x2 matrices
self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
- # Complex
+
+ def testSymmetricPositiveDefiniteComplex(self):
+ matrix1 = np.array([[2., 1.], [1., 2.]])
+ matrix2 = np.array([[3., -1.], [-1., 3.]])
matrix1 = matrix1.astype(np.complex64)
matrix1 += 1j * matrix1
matrix2 = matrix2.astype(np.complex64)
@@ -116,35 +124,31 @@ class ExponentialOpTest(test.TestCase):
# When the exponential of a non-square matrix is attempted we should return
# an error
with self.assertRaises(ValueError):
- gen_linalg_ops.matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]]))
+ linalg_impl.matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]]))
def testWrongDimensions(self):
# The input to the exponential should be at least a 2-dimensional tensor.
tensor3 = constant_op.constant([1., 2.])
with self.assertRaises(ValueError):
- gen_linalg_ops.matrix_exponential(tensor3)
+ linalg_impl.matrix_exponential(tensor3)
def testEmpty(self):
self._verifyExponentialReal(np.empty([0, 2, 2]))
self._verifyExponentialReal(np.empty([2, 0, 0]))
- def testRandomSmallAndLarge(self):
- np.random.seed(42)
- for dtype in np.float32, np.float64, np.complex64, np.complex128:
- for batch_dims in [(), (1,), (3,), (2, 2)]:
- for size in 8, 31, 32:
- shape = batch_dims + (size, size)
- matrix = np.random.uniform(
- low=-1.0, high=1.0,
- size=np.prod(shape)).reshape(shape).astype(dtype)
- self._verifyExponentialReal(matrix)
+ def testDynamic(self):
+ with self.test_session(use_gpu=True) as sess:
+ inp = array_ops.placeholder(ops.dtypes.float32)
+ expm = linalg_impl.matrix_exponential(inp)
+ matrix = np.array([[1., 2.], [3., 4.]])
+ sess.run(expm, feed_dict={inp: matrix})
def testConcurrentExecutesWithoutError(self):
with self.test_session(use_gpu=True) as sess:
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
- expm1 = gen_linalg_ops.matrix_exponential(matrix1)
- expm2 = gen_linalg_ops.matrix_exponential(matrix2)
+ expm1 = linalg_impl.matrix_exponential(matrix1)
+ expm2 = linalg_impl.matrix_exponential(matrix2)
expm = sess.run([expm1, expm2])
self.assertAllEqual(expm[0], expm[1])
@@ -180,7 +184,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
session.Session() as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
- expm = gen_linalg_ops.matrix_exponential(matrix)
+ expm = linalg_impl.matrix_exponential(matrix)
variables.global_variables_initializer().run()
self.run_op_benchmark(
sess,
@@ -189,6 +193,66 @@ class MatrixExponentialBenchmark(test.Benchmark):
name="matrix_exponential_cpu_{shape}".format(
shape=shape))
+ if test.is_gpu_available(True):
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/gpu:0"):
+ matrix = self._GenerateMatrix(shape)
+ expm = linalg_impl.matrix_exponential(matrix)
+ variables.global_variables_initializer().run()
+ self.run_op_benchmark(
+ sess,
+ control_flow_ops.group(expm),
+ min_iters=25,
+ name="matrix_exponential_gpu_{shape}".format(
+ shape=shape))
+
+
+def _TestRandomSmall(dtype, batch_dims, size):
+
+ def Test(self):
+ np.random.seed(42)
+ shape = batch_dims + (size, size)
+ matrix = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=shape).astype(dtype)
+ self._verifyExponentialReal(matrix)
+
+ return Test
+
+
+def _TestL1Norms(dtype, shape, scale):
+
+ def Test(self):
+ np.random.seed(42)
+ matrix = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(shape)).reshape(shape).astype(dtype)
+ l1_norm = np.max(np.sum(np.abs(matrix), axis=matrix.ndim-2))
+ matrix /= l1_norm
+ self._verifyExponentialReal(scale * matrix)
+
+ return Test
+
if __name__ == "__main__":
+ for dtype_ in [np.float32, np.float64, np.complex64, np.complex128]:
+ for batch_ in [(), (2,), (2, 2)]:
+ for size_ in [4, 7]:
+ name = "%s_%d_%d" % (dtype_.__name__, len(batch_), size_)
+ setattr(ExponentialOpTest, "testRandomSmall_" + name,
+ _TestRandomSmall(dtype_, batch_, size_))
+
+ for shape_ in [(3, 3), (2, 3, 3)]:
+ for dtype_ in [np.float32, np.complex64]:
+ for scale_ in [0.1, 1.5, 5.0, 20.0]:
+ name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*10))
+ setattr(ExponentialOpTest, "testL1Norms_" + name,
+ _TestL1Norms(dtype_, shape_, scale_))
+ for dtype_ in [np.float64, np.complex128]:
+ for scale_ in [0.01, 0.2, 0.5, 1.5, 6.0, 25.0]:
+ name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*100))
+ setattr(ExponentialOpTest, "testL1Norms_" + name,
+ _TestL1Norms(dtype_, shape_, scale_))
test.main()
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 402f67619b..4a1fc1d9a9 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -283,7 +283,7 @@ class SliceTest(test.TestCase):
# unintended behavior is prevented.
c = constant_op.constant(5.0)
with self.assertRaisesWithPredicateMatch(
- TypeError, lambda e: "Tensor objects are not iterable" in str(e)):
+ TypeError, lambda e: "Tensor objects are only iterable" in str(e)):
for _ in c:
pass
diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc
index ec1ba7b8f7..5765b17594 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor.cc
@@ -136,6 +136,33 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
return Status::OK();
}
+Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
+ PyObject** ptr_owner) {
+ *ptr_owner = nullptr;
+ if (!PyUnicode_Check(obj)) {
+ char* buf;
+ if (PyBytes_AsStringAndSize(obj, &buf, len) != 0) {
+ return errors::Internal("Unable to get element as bytes.");
+ }
+ *ptr = buf;
+ return Status::OK();
+ }
+#if (PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 3))
+ *ptr = PyUnicode_AsUTF8AndSize(obj, len);
+ if (*ptr != nullptr) return Status::OK();
+#else
+ PyObject* utemp = PyUnicode_AsUTF8String(obj);
+ char* buf;
+ if (utemp != nullptr && PyBytes_AsStringAndSize(utemp, &buf, len) != -1) {
+ *ptr = buf;
+ *ptr_owner = utemp;
+ return Status::OK();
+ }
+ Py_XDECREF(utemp);
+#endif
+ return errors::Internal("Unable to convert element to UTF-8.");
+}
+
// Iterate over the string array 'array', extract the ptr and len of each string
// element and call f(ptr, len).
template <typename F>
@@ -148,33 +175,12 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) {
if (!item) {
return errors::Internal("Unable to get element from the feed - no item.");
}
- char* ptr;
Py_ssize_t len;
-
- if (PyUnicode_Check(item.get())) {
-#if PY_VERSION_HEX >= 0x03030000
- // Accept unicode by converting to UTF-8 bytes.
- ptr = PyUnicode_AsUTF8AndSize(item.get(), &len);
- if (!ptr) {
- return errors::Internal("Unable to get element as UTF-8.");
- }
- f(ptr, len);
-#else
- PyObject* utemp = PyUnicode_AsUTF8String(item.get());
- if (!utemp || PyBytes_AsStringAndSize(utemp, &ptr, &len) == -1) {
- Py_XDECREF(utemp);
- return errors::Internal("Unable to convert element to UTF-8.");
- }
- f(ptr, len);
- Py_DECREF(utemp);
-#endif
- } else {
- int success = PyBytes_AsStringAndSize(item.get(), &ptr, &len);
- if (success != 0) {
- return errors::Internal("Unable to get element as bytes.");
- }
- f(ptr, len);
- }
+ const char* ptr;
+ PyObject* ptr_owner;
+ TF_RETURN_IF_ERROR(PyObjectToString(item.get(), &ptr, &len, &ptr_owner));
+ f(ptr, len);
+ Py_XDECREF(ptr_owner);
PyArray_ITER_NEXT(iter.get());
}
return Status::OK();
@@ -186,10 +192,11 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
size_t* size, void** buffer) {
// Compute bytes needed for encoding.
*size = 0;
- TF_RETURN_IF_ERROR(PyBytesArrayMap(array, [&size](char* ptr, Py_ssize_t len) {
- *size +=
- sizeof(tensorflow::uint64) + tensorflow::core::VarintLength(len) + len;
- }));
+ TF_RETURN_IF_ERROR(
+ PyBytesArrayMap(array, [&size](const char* ptr, Py_ssize_t len) {
+ *size += sizeof(tensorflow::uint64) +
+ tensorflow::core::VarintLength(len) + len;
+ }));
// Encode all strings.
std::unique_ptr<char[]> base_ptr(new char[*size]);
char* base = base_ptr.get();
@@ -198,7 +205,7 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
TF_RETURN_IF_ERROR(PyBytesArrayMap(
- array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) {
+ array, [&data_start, &dst, &offsets](const char* ptr, Py_ssize_t len) {
*offsets = (dst - data_start);
offsets++;
dst = tensorflow::core::EncodeVarint64(dst, len);
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 57139986af..7c107138be 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -333,6 +333,35 @@ class NumpyTensorBuffer : public TensorBuffer {
void* data_;
};
+Status PyObjectToString(PyObject* obj, string* str) {
+ char* py_bytes;
+ Py_ssize_t size;
+ if (PyBytes_AsStringAndSize(obj, &py_bytes, &size) != -1) {
+ str->assign(py_bytes, size);
+ return Status::OK();
+ }
+#if PY_MAJOR_VERSION >= 3
+ const char* ptr = PyUnicode_AsUTF8AndSize(obj, &size);
+ if (ptr != nullptr) {
+ str->assign(ptr, size);
+ return Status::OK();
+ }
+#else
+ if (PyUnicode_Check(obj)) {
+ PyObject* unicode = PyUnicode_AsUTF8String(obj);
+ char* ptr;
+ if (unicode && PyString_AsStringAndSize(unicode, &ptr, &size) != -1) {
+ str->assign(ptr, size);
+ Py_DECREF(unicode);
+ return Status::OK();
+ }
+ Py_XDECREF(unicode);
+ }
+#endif
+ return errors::Unimplemented("Unsupported object type ",
+ obj->ob_type->tp_name);
+}
+
Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
PyArrayObject* input = reinterpret_cast<PyArrayObject*>(obj);
DataType dtype = DT_INVALID;
@@ -348,29 +377,7 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
auto tflat = t.flat<string>();
PyObject** input_data = reinterpret_cast<PyObject**>(PyArray_DATA(input));
for (int i = 0; i < tflat.dimension(0); ++i) {
- char* el;
- Py_ssize_t el_size;
- if (PyBytes_AsStringAndSize(input_data[i], &el, &el_size) == -1) {
-#if PY_MAJOR_VERSION >= 3
- el = PyUnicode_AsUTF8AndSize(input_data[i], &el_size);
-#else
- el = nullptr;
- if (PyUnicode_Check(input_data[i])) {
- PyObject* unicode = PyUnicode_AsUTF8String(input_data[i]);
- if (unicode) {
- if (PyString_AsStringAndSize(unicode, &el, &el_size) == -1) {
- Py_DECREF(unicode);
- el = nullptr;
- }
- }
- }
-#endif
- if (!el) {
- return errors::Unimplemented("Unsupported object type ",
- input_data[i]->ob_type->tp_name);
- }
- }
- tflat(i) = string(el, el_size);
+ TF_RETURN_IF_ERROR(PyObjectToString(input_data[i], &tflat(i)));
}
*ret = t;
break;
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index ba749da47a..ae76fcceba 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -47,6 +47,9 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
}
PyRecordWriter::~PyRecordWriter() {
+ // The writer flushes through the file when it is closed (zlib), so destroy the writer before the file.
+ writer_.reset();
+ file_.reset();
}
bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) {
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index dcc1a25f42..4a13e8c428 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -318,5 +318,40 @@ class TFRecordIteratorTest(TFCompressionTestCase):
for _ in tf_record.tf_record_iterator(fn_truncated):
pass
+
+class TFRecordWriterCloseAndFlushTests(test.TestCase):
+
+ def setUp(self, compression_type=TFRecordCompressionType.NONE):
+ super(TFRecordWriterCloseAndFlushTests, self).setUp()
+ self._fn = os.path.join(self.get_temp_dir(), "tf_record_writer_test.txt")
+ self._options = tf_record.TFRecordOptions(compression_type)
+ self._writer = tf_record.TFRecordWriter(self._fn, self._options)
+ self._num_records = 20
+
+ def _Record(self, r):
+ return compat.as_bytes("Record %d" % r)
+
+ def testWriteAndLeaveOpen(self):
+ records = list(map(self._Record, range(self._num_records)))
+ for record in records:
+ self._writer.write(record)
+
+ # Verify no segfault if writer isn't explicitly closed.
+
+
+class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests):
+
+ def setUp(self):
+ super(TFRecordWriterCloseAndFlushGzipTests,
+ self).setUp(TFRecordCompressionType.GZIP)
+
+
+class TFRecordWriterCloseAndFlushZlibTests(TFRecordWriterCloseAndFlushTests):
+
+ def setUp(self):
+ super(TFRecordWriterCloseAndFlushZlibTests,
+ self).setUp(TFRecordCompressionType.ZLIB)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index aeac61c005..c7061b36dd 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -817,11 +817,12 @@ class GradLoopState(object):
outer_forward_ctxt = forward_ctxt.outer_context
# Add the forward loop counter.
- if outer_forward_ctxt:
- outer_forward_ctxt.Enter()
- cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
- if outer_forward_ctxt:
- outer_forward_ctxt.Exit()
+ with forward_ctxt._graph.as_default(): # pylint: disable=protected-access
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Enter()
+ cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Exit()
self._forward_context = forward_ctxt
self._forward_index = forward_index
@@ -984,60 +985,61 @@ class GradLoopState(object):
for the stack can't be found.
"""
# curr_ctxt is the context that tf.gradients was called in.
- curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
- with ops.control_dependencies(None):
- if curr_ctxt:
- curr_ctxt.Enter()
- with ops.colocate_with(value):
- # We only need to pass maximum_iterations to the stack if
- # we're inside an XLA context.
- if not util.IsInXLAContext(value.op):
- max_size = constant_op.constant(-1, dtypes.int32)
- else:
- max_size = GetMaxSizeFromNestedMaximumIterations(
- value, self.forward_context)
- acc = gen_data_flow_ops.stack_v2(
- max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
- if curr_ctxt:
- curr_ctxt.Exit()
-
- # Make acc available in the forward context.
- enter_acc = self.forward_context.AddValue(acc)
-
- # Add the stack_push op in the context of value.op.
- swap_enabled = self.forward_context.swap_memory
- value_ctxt = util.GetOutputContext(value.op)
- if value_ctxt == self.forward_context:
- # value is not nested in the forward context.
- self.forward_context.Enter()
- push = gen_data_flow_ops.stack_push_v2(
- enter_acc, value, swap_memory=swap_enabled)
- self.forward_context.Exit()
- # Protect stack push and order it before forward_index.
- self.forward_index.op._add_control_input(push.op)
- else:
- # value is in a cond context within the forward context.
- if not isinstance(value_ctxt, CondContext):
- raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
- if dead_branch:
- # The special case for creating a zero tensor for a dead
- # branch of a switch. See ControlFlowState.ZerosLike().
- value_ctxt.outer_context.Enter()
+ with self._forward_index.graph.as_default():
+ curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ with ops.control_dependencies(None):
+ if curr_ctxt:
+ curr_ctxt.Enter()
+ with ops.colocate_with(value):
+ # We only need to pass maximum_iterations to the stack if
+ # we're inside an XLA context.
+ if not util.IsInXLAContext(value.op):
+ max_size = constant_op.constant(-1, dtypes.int32)
+ else:
+ max_size = GetMaxSizeFromNestedMaximumIterations(
+ value, self.forward_context)
+ acc = gen_data_flow_ops.stack_v2(
+ max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
+ if curr_ctxt:
+ curr_ctxt.Exit()
+
+ # Make acc available in the forward context.
+ enter_acc = self.forward_context.AddValue(acc)
+
+ # Add the stack_push op in the context of value.op.
+ swap_enabled = self.forward_context.swap_memory
+ value_ctxt = util.GetOutputContext(value.op)
+ if value_ctxt == self.forward_context:
+ # value is not nested in the forward context.
+ self.forward_context.Enter()
push = gen_data_flow_ops.stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
- value_ctxt.outer_context.Exit()
- push.op._set_control_flow_context(value_ctxt)
+ self.forward_context.Exit()
+ # Protect stack push and order it before forward_index.
+ self.forward_index.op._add_control_input(push.op)
else:
- value_ctxt.Enter()
- push = gen_data_flow_ops.stack_push_v2(
- enter_acc, value, swap_memory=swap_enabled)
- value_ctxt.Exit()
- # Protect stack push and order it before forward_sync.
- self.forward_sync._add_control_input(push.op)
- # Order stack push after the successor of forward_index
- add_op = self.forward_index.op.inputs[0].op
- push.op._add_control_input(add_op)
- return acc
+ # value is in a cond context within the forward context.
+ if not isinstance(value_ctxt, CondContext):
+ raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
+ if dead_branch:
+ # The special case for creating a zero tensor for a dead
+ # branch of a switch. See ControlFlowState.ZerosLike().
+ value_ctxt.outer_context.Enter()
+ push = gen_data_flow_ops.stack_push_v2(
+ enter_acc, value, swap_memory=swap_enabled)
+ value_ctxt.outer_context.Exit()
+ push.op._set_control_flow_context(value_ctxt)
+ else:
+ value_ctxt.Enter()
+ push = gen_data_flow_ops.stack_push_v2(
+ enter_acc, value, swap_memory=swap_enabled)
+ value_ctxt.Exit()
+ # Protect stack push and order it before forward_sync.
+ self.forward_sync._add_control_input(push.op)
+ # Order stack push after the successor of forward_index
+ add_op = self.forward_index.op.inputs[0].op
+ push.op._add_control_input(add_op)
+ return acc
def AddBackpropAccumulatedValue(self, history_value, value,
dead_branch=False):
@@ -2215,6 +2217,7 @@ class WhileContext(ControlFlowContext):
self._loop_exits = []
# The list of enter tensors for loop variables.
self._loop_enters = []
+ self._graph = ops.get_default_graph()
def _init_from_proto(self, context_def, import_scope=None):
"""Creates a new `WhileContext` from protocol buffer.
@@ -2268,6 +2271,7 @@ class WhileContext(ControlFlowContext):
op._set_attr("frame_name",
attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
# pylint: enable=protected-access
+ self._graph = ops.get_default_graph()
@property
def maximum_iterations(self):
@@ -2592,7 +2596,14 @@ class WhileContext(ControlFlowContext):
Returns:
The loop index.
"""
- one = constant_op.constant(1, name="b_count")
+ in_separate_functions = count.graph is not ops.get_default_graph()
+ if in_separate_functions:
+ # Brings the count into this graph
+ count = array_ops.identity(count)
+ else:
+ # TODO(apassos) XLA expects this constant to be created outside the loop,
+ # so doing that for now.
+ one = constant_op.constant(1, name="b_count")
self.Enter()
self.AddName(count.name)
@@ -2607,6 +2618,8 @@ class WhileContext(ControlFlowContext):
merge_count = merge([enter_count, enter_count])[0]
self._pivot_for_pred = merge_count
+ if in_separate_functions:
+ one = constant_op.constant(1, name="b_count")
pred = math_ops.greater_equal(merge_count, one)
self._pivot = loop_cond(pred, name="b_count")
switch_count = switch(merge_count, self._pivot)
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index ca24f11054..9f77a6cca1 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -142,9 +142,9 @@ def _graph_mode_decorator(f, *args, **kwargs):
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
variables = list(set(tape.watched_variables()) - set(args))
- grad_argspec = tf_inspect.getargspec(grad_fn)
+ grad_argspec = tf_inspect.getfullargspec(grad_fn)
variables_in_signature = ("variables" in grad_argspec.args or
- grad_argspec.keywords)
+ grad_argspec.varkw)
if variables and not variables_in_signature:
raise TypeError("If using @custom_gradient with a function that "
"uses variables, then grad_fn must accept a keyword "
@@ -194,9 +194,9 @@ def _eager_mode_decorator(f, *args, **kwargs):
# The variables that grad_fn needs to return gradients for are the set of
# variables used that are *not* part of the inputs.
variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
- grad_argspec = tf_inspect.getargspec(grad_fn)
- if (variables and
- not ("variables" in grad_argspec.args or grad_argspec.keywords)):
+ grad_argspec = tf_inspect.getfullargspec(grad_fn)
+ if (variables and ("variables" not in grad_argspec.args) and
+ not grad_argspec.varkw):
raise TypeError("If using @custom_gradient with a function that "
"uses variables, then grad_fn must accept a keyword "
"argument 'variables'.")
diff --git a/tensorflow/python/ops/linalg/BUILD b/tensorflow/python/ops/linalg/BUILD
index 07659ef44c..c7314d7774 100644
--- a/tensorflow/python/ops/linalg/BUILD
+++ b/tensorflow/python/ops/linalg/BUILD
@@ -29,6 +29,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:special_math_ops",
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index 8343c62816..1e3d817980 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -18,8 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@@ -38,8 +41,6 @@ diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
-expm = gen_linalg_ops.matrix_exponential
-tf_export('linalg.expm')(expm)
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
@@ -114,3 +115,214 @@ def adjoint(matrix, name=None):
with ops.name_scope(name, 'adjoint', [matrix]):
matrix = ops.convert_to_tensor(matrix, name='matrix')
return array_ops.matrix_transpose(matrix, conjugate=True)
+
+
+# This section is ported nearly verbatim from Eigen's implementation:
+# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html
+def _matrix_exp_pade3(matrix):
+ """3rd-order Pade approximant for matrix exponential."""
+ b = [120.0, 60.0, 12.0]
+ b = [constant_op.constant(x, matrix.dtype) for x in b]
+ ident = linalg_ops.eye(array_ops.shape(matrix)[-2],
+ batch_shape=array_ops.shape(matrix)[:-2],
+ dtype=matrix.dtype)
+ matrix_2 = math_ops.matmul(matrix, matrix)
+ tmp = matrix_2 + b[1] * ident
+ matrix_u = math_ops.matmul(matrix, tmp)
+ matrix_v = b[2] * matrix_2 + b[0] * ident
+ return matrix_u, matrix_v
+
+
+def _matrix_exp_pade5(matrix):
+ """5th-order Pade approximant for matrix exponential."""
+ b = [30240.0, 15120.0, 3360.0, 420.0, 30.0]
+ b = [constant_op.constant(x, matrix.dtype) for x in b]
+ ident = linalg_ops.eye(array_ops.shape(matrix)[-2],
+ batch_shape=array_ops.shape(matrix)[:-2],
+ dtype=matrix.dtype)
+ matrix_2 = math_ops.matmul(matrix, matrix)
+ matrix_4 = math_ops.matmul(matrix_2, matrix_2)
+ tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident
+ matrix_u = math_ops.matmul(matrix, tmp)
+ matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
+ return matrix_u, matrix_v
+
+
+def _matrix_exp_pade7(matrix):
+ """7th-order Pade approximant for matrix exponential."""
+ b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0]
+ b = [constant_op.constant(x, matrix.dtype) for x in b]
+ ident = linalg_ops.eye(array_ops.shape(matrix)[-2],
+ batch_shape=array_ops.shape(matrix)[:-2],
+ dtype=matrix.dtype)
+ matrix_2 = math_ops.matmul(matrix, matrix)
+ matrix_4 = math_ops.matmul(matrix_2, matrix_2)
+ matrix_6 = math_ops.matmul(matrix_4, matrix_2)
+ tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident
+ matrix_u = math_ops.matmul(matrix, tmp)
+ matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
+ return matrix_u, matrix_v
+
+
+def _matrix_exp_pade9(matrix):
+ """9th-order Pade approximant for matrix exponential."""
+ b = [
+ 17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0,
+ 2162160.0, 110880.0, 3960.0, 90.0
+ ]
+ b = [constant_op.constant(x, matrix.dtype) for x in b]
+ ident = linalg_ops.eye(array_ops.shape(matrix)[-2],
+ batch_shape=array_ops.shape(matrix)[:-2],
+ dtype=matrix.dtype)
+ matrix_2 = math_ops.matmul(matrix, matrix)
+ matrix_4 = math_ops.matmul(matrix_2, matrix_2)
+ matrix_6 = math_ops.matmul(matrix_4, matrix_2)
+ matrix_8 = math_ops.matmul(matrix_6, matrix_2)
+ tmp = (
+ matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 +
+ b[1] * ident)
+ matrix_u = math_ops.matmul(matrix, tmp)
+ matrix_v = (
+ b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 +
+ b[0] * ident)
+ return matrix_u, matrix_v
+
+
+def _matrix_exp_pade13(matrix):
+ """13th-order Pade approximant for matrix exponential."""
+ b = [
+ 64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
+ 1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
+ 33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0
+ ]
+ b = [constant_op.constant(x, matrix.dtype) for x in b]
+ ident = linalg_ops.eye(array_ops.shape(matrix)[-2],
+ batch_shape=array_ops.shape(matrix)[:-2],
+ dtype=matrix.dtype)
+ matrix_2 = math_ops.matmul(matrix, matrix)
+ matrix_4 = math_ops.matmul(matrix_2, matrix_2)
+ matrix_6 = math_ops.matmul(matrix_4, matrix_2)
+ tmp_u = (
+ math_ops.matmul(matrix_6,
+ matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) +
+ b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident)
+ matrix_u = math_ops.matmul(matrix, tmp_u)
+ tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2
+ matrix_v = (
+ math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 +
+ b[2] * matrix_2 + b[0] * ident)
+ return matrix_u, matrix_v
+
+
+@tf_export('linalg.expm')
+def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
+ r"""Computes the matrix exponential of one or more square matrices.
+
+ exp(A) = \sum_{n=0}^\infty A^n/n!
+
+ The exponential is computed using a combination of the scaling and squaring
+ method and the Pade approximation. Details can be found in:
+ Nicholas J. Higham, "The scaling and squaring method for the matrix
+ exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
+
+ The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+ form square matrices. The output is a tensor of the same shape as the input
+ containing the exponential for all input submatrices `[..., :, :]`.
+
+ Args:
+ input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
+ or `complex128` with shape `[..., M, M]`.
+ name: A name to give this `Op` (optional).
+
+ Returns:
+ the matrix exponential of the input.
+
+ Raises:
+ ValueError: An unsupported type is provided as input.
+
+ @compatibility(scipy)
+ Equivalent to scipy.linalg.expm
+ @end_compatibility
+ """
+ with ops.name_scope(name, 'matrix_exponential', [input]):
+ matrix = ops.convert_to_tensor(input, name='input')
+ if matrix.shape[-2:] == [0, 0]:
+ return matrix
+ batch_shape = matrix.shape[:-2]
+ if not batch_shape.is_fully_defined():
+ batch_shape = array_ops.shape(matrix)[:-2]
+
+ # Flatten the batch dimensions into one so the `where` selections below can pick per-matrix results with a 1-D condition.
+ matrix = array_ops.reshape(
+ matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
+ l1_norm = math_ops.reduce_max(
+ math_ops.reduce_sum(math_ops.abs(matrix),
+ axis=array_ops.size(array_ops.shape(matrix)) - 2),
+ axis=-1)
+ const = lambda x: constant_op.constant(x, l1_norm.dtype)
+ def _nest_where(vals, cases):
+ assert len(vals) == len(cases) - 1
+ if len(vals) == 1:
+ return array_ops.where(
+ math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
+ else:
+ return array_ops.where(
+ math_ops.less(l1_norm, const(vals[0])), cases[0],
+ _nest_where(vals[1:], cases[1:]))
+
+ if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
+ maxnorm = const(3.925724783138660)
+ squarings = math_ops.maximum(
+ math_ops.floor(
+ math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
+ u3, v3 = _matrix_exp_pade3(matrix)
+ u5, v5 = _matrix_exp_pade5(matrix)
+ u7, v7 = _matrix_exp_pade7(
+ matrix / math_ops.pow(
+ constant_op.constant(2.0, dtype=matrix.dtype),
+ math_ops.cast(squarings, matrix.dtype))[...,
+ array_ops.newaxis,
+ array_ops.newaxis])
+ conds = (4.258730016922831e-001, 1.880152677804762e+000)
+ u = _nest_where(conds, (u3, u5, u7))
+ v = _nest_where(conds, (v3, v5, v7))
+ elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
+ maxnorm = const(5.371920351148152)
+ squarings = math_ops.maximum(
+ math_ops.floor(
+ math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
+ u3, v3 = _matrix_exp_pade3(matrix)
+ u5, v5 = _matrix_exp_pade5(matrix)
+ u7, v7 = _matrix_exp_pade7(matrix)
+ u9, v9 = _matrix_exp_pade9(matrix)
+ u13, v13 = _matrix_exp_pade13(
+ matrix / math_ops.pow(
+ constant_op.constant(2.0, dtype=matrix.dtype),
+ math_ops.cast(squarings, matrix.dtype))[...,
+ array_ops.newaxis,
+ array_ops.newaxis])
+ conds = (1.495585217958292e-002,
+ 2.539398330063230e-001,
+ 9.504178996162932e-001,
+ 2.097847961257068e+000)
+ u = _nest_where(conds, (u3, u5, u7, u9, u13))
+ v = _nest_where(conds, (v3, v5, v7, v9, v13))
+ else:
+ raise ValueError(
+ 'tf.linalg.expm does not support matrices of type %s' % matrix.dtype)
+ numer = u + v
+ denom = -u + v
+ result = linalg_ops.matrix_solve(denom, numer)
+ max_squarings = math_ops.reduce_max(squarings)
+
+ i = const(0.0)
+ c = lambda i, r: math_ops.less(i, max_squarings)
+ def b(i, r):
+ return i+1, array_ops.where(math_ops.less(i, squarings),
+ math_ops.matmul(r, r), r)
+ _, result = control_flow_ops.while_loop(c, b, [i, result])
+ if not matrix.shape.is_fully_defined():
+ return array_ops.reshape(
+ result,
+ array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
+ return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))
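
# A short usage sketch (illustrative, not part of the change above) for the
# Python expm registered as tf.linalg.expm, assuming the graph-mode API of this
# release (tf.Session). The implementation scales the input down by
# 2**squarings, evaluates a Pade approximant via matrix_solve(-u + v, u + v),
# and undoes the scaling by repeated squaring in the while_loop above.
import numpy as np
import tensorflow as tf

_a = np.array([[0., 1.], [-1., 0.]], dtype=np.float64)
with tf.Session() as _sess:
  _e = _sess.run(tf.linalg.expm(tf.constant(_a)))
# For this skew-symmetric input, the exponential is the rotation matrix
# [[cos 1, sin 1], [-sin 1, cos 1]], which _e should reproduce closely.
assert np.allclose(_e, [[np.cos(1.), np.sin(1.)], [-np.sin(1.), np.cos(1.)]])
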
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 8b259b6b6b..d533731c07 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -943,9 +943,10 @@ class ResourceVariable(variables.RefVariable):
if self.trainable:
tape.watch_variable(self)
return _UnreadVariable(
- self._handle, self.dtype, self._shape, self._in_graph_mode,
- self._handle_deleter if not self._in_graph_mode else None, op,
- self._unique_id)
+ handle=self._handle, dtype=self.dtype, shape=self._shape,
+ in_graph_mode=self._in_graph_mode,
+ deleter=self._handle_deleter if not self._in_graph_mode else None,
+ parent_op=op, parent_name=self._handle_name, unique_id=self._unique_id)
def assign(self, value, use_locking=None, name=None, read_value=True):
"""Assigns a new value to this variable.
@@ -1059,7 +1060,8 @@ class _UnreadVariable(ResourceVariable):
"""
def __init__(self, handle, dtype, # pylint: disable=super-init-not-called
- shape, in_graph_mode, deleter, parent_op, unique_id):
+ shape, in_graph_mode, deleter, parent_op, parent_name,
+ unique_id):
# We do not call super init on purpose.
self._trainable = False
self._save_slice_info = None
@@ -1087,7 +1089,10 @@ class _UnreadVariable(ResourceVariable):
@property
def name(self):
- return self._parent_op.name
+ if self._in_graph_mode:
+ return self._parent_op.name
+ else:
+ return "UnreadVariable"
def value(self):
return self._read_variable_op()
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 5d7535cf34..1b69e0d06c 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -29,6 +29,7 @@ limitations under the License.
%rename("%s") TFE_ContextGetDevicePlacementPolicy;
%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
%rename("%s") TFE_ContextSetAsyncForThread;
+%rename("%s") TFE_ContextSetServerDef;
%rename("%s") TFE_ContextAsyncWait;
%rename("%s") TFE_ContextAsyncClearError;
%rename("%s") TFE_OpNameGetAttrType;
@@ -59,7 +60,6 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
%rename("%s") TFE_ContextOptionsSetAsync;
-%rename("%s") TFE_ContextOptionsSetServerDef;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 4349699a94..130fe70beb 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -55,6 +55,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@@ -78,7 +79,7 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
- not saver_lib.checkpoint_exists(input_checkpoint)):
+ not checkpoint_management.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
new file mode 100644
index 0000000000..aaddc015ed
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -0,0 +1,406 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=invalid-name
+"""Save and restore variables."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import re
+
+from google.protobuf import text_format
+
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.util.tf_export import tf_export
+
+
+def _GetCheckpointFilename(save_dir, latest_filename):
+ """Returns a filename for storing the CheckpointState.
+
+ Args:
+ save_dir: The directory for saving and restoring checkpoints.
+ latest_filename: Name of the file in 'save_dir' that is used
+ to store the CheckpointState.
+
+ Returns:
+ The path of the file that contains the CheckpointState proto.
+ """
+ if latest_filename is None:
+ latest_filename = "checkpoint"
+ return os.path.join(save_dir, latest_filename)
+
+
+@tf_export("train.generate_checkpoint_state_proto")
+def generate_checkpoint_state_proto(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None):
+ """Generates a checkpoint state proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+
+ Returns:
+ CheckpointState proto with model_checkpoint_path and
+ all_model_checkpoint_paths updated to either absolute paths or
+ relative paths to the current save_dir.
+ """
+ if all_model_checkpoint_paths is None:
+ all_model_checkpoint_paths = []
+
+ if (not all_model_checkpoint_paths or
+ all_model_checkpoint_paths[-1] != model_checkpoint_path):
+ logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
+ model_checkpoint_path)
+ all_model_checkpoint_paths.append(model_checkpoint_path)
+
+ # Relative paths need to be rewritten to be relative to the "save_dir"
+ # if model_checkpoint_path already contains "save_dir".
+ if not os.path.isabs(save_dir):
+ if not os.path.isabs(model_checkpoint_path):
+ model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
+ for i in range(len(all_model_checkpoint_paths)):
+ p = all_model_checkpoint_paths[i]
+ if not os.path.isabs(p):
+ all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
+
+ coord_checkpoint_proto = CheckpointState(
+ model_checkpoint_path=model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths)
+
+ return coord_checkpoint_proto
+
+
+@tf_export("train.update_checkpoint_state")
+def update_checkpoint_state(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None,
+ latest_filename=None):
+ """Updates the content of the 'checkpoint' file.
+
+ This updates the checkpoint file containing a CheckpointState
+ proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+ latest_filename: Optional name of the checkpoint file. Defaults to
+ 'checkpoint'.
+
+ Raises:
+ RuntimeError: If any of the model checkpoint paths conflict with the file
+ containing CheckpointState.
+ """
+ update_checkpoint_state_internal(
+ save_dir=save_dir,
+ model_checkpoint_path=model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths,
+ latest_filename=latest_filename,
+ save_relative_paths=False)
+
+
+def update_checkpoint_state_internal(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None,
+ latest_filename=None,
+ save_relative_paths=False):
+ """Updates the content of the 'checkpoint' file.
+
+ This updates the checkpoint file containing a CheckpointState
+ proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+ latest_filename: Optional name of the checkpoint file. Defaults to
+ 'checkpoint'.
+ save_relative_paths: If `True`, will write relative paths to the checkpoint
+ state file.
+
+ Raises:
+ RuntimeError: If any of the model checkpoint paths conflict with the file
+ containing CheckpointState.
+ """
+ # Writes the "checkpoint" file for the coordinator for later restoration.
+ coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
+ if save_relative_paths:
+ if os.path.isabs(model_checkpoint_path):
+ rel_model_checkpoint_path = os.path.relpath(
+ model_checkpoint_path, save_dir)
+ else:
+ rel_model_checkpoint_path = model_checkpoint_path
+ rel_all_model_checkpoint_paths = []
+ for p in all_model_checkpoint_paths:
+ if os.path.isabs(p):
+ rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
+ else:
+ rel_all_model_checkpoint_paths.append(p)
+ ckpt = generate_checkpoint_state_proto(
+ save_dir,
+ rel_model_checkpoint_path,
+ all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
+ else:
+ ckpt = generate_checkpoint_state_proto(
+ save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths)
+
+ if coord_checkpoint_filename == ckpt.model_checkpoint_path:
+ raise RuntimeError("Save path '%s' conflicts with path used for "
+ "checkpoint state. Please use a different save path." %
+ model_checkpoint_path)
+
+ # Preventing potential read/write race condition by *atomically* writing to a
+ # file.
+ file_io.atomic_write_string_to_file(coord_checkpoint_filename,
+ text_format.MessageToString(ckpt))
+
+
+@tf_export("train.get_checkpoint_state")
+def get_checkpoint_state(checkpoint_dir, latest_filename=None):
+ """Returns CheckpointState proto from the "checkpoint" file.
+
+ If the "checkpoint" file contains a valid CheckpointState
+ proto, returns it.
+
+ Args:
+ checkpoint_dir: The directory of checkpoints.
+ latest_filename: Optional name of the checkpoint file. Defaults to
+ 'checkpoint'.
+
+ Returns:
+ A CheckpointState if the state was available, None
+ otherwise.
+
+ Raises:
+ ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
+ """
+ ckpt = None
+ coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
+ latest_filename)
+ f = None
+ try:
+ # Check that the file exists before opening it to avoid
+ # many lines of errors from colossus in the logs.
+ if file_io.file_exists(coord_checkpoint_filename):
+ file_content = file_io.read_file_to_string(
+ coord_checkpoint_filename)
+ ckpt = CheckpointState()
+ text_format.Merge(file_content, ckpt)
+ if not ckpt.model_checkpoint_path:
+ raise ValueError("Invalid checkpoint state loaded from "
+ + checkpoint_dir)
+ # For relative model_checkpoint_path and all_model_checkpoint_paths,
+ # prepend checkpoint_dir.
+ if not os.path.isabs(ckpt.model_checkpoint_path):
+ ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
+ ckpt.model_checkpoint_path)
+ for i in range(len(ckpt.all_model_checkpoint_paths)):
+ p = ckpt.all_model_checkpoint_paths[i]
+ if not os.path.isabs(p):
+ ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
+ except errors.OpError as e:
+ # It's ok if the file cannot be read
+ logging.warning("%s: %s", type(e).__name__, e)
+ logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
+ return None
+ except text_format.ParseError as e:
+ logging.warning("%s: %s", type(e).__name__, e)
+ logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
+ return None
+ finally:
+ if f:
+ f.close()
+ return ckpt
+
+
+def _prefix_to_checkpoint_path(prefix, format_version):
+ """Returns the pathname of a checkpoint file, given the checkpoint prefix.
+
+ For V1 checkpoint, simply returns the prefix itself (the data file). For V2,
+ returns the pathname to the index file.
+
+ Args:
+ prefix: a string, the prefix of a checkpoint.
+ format_version: the checkpoint format version that corresponds to the
+ prefix.
+ Returns:
+ The pathname of a checkpoint file, taking into account the checkpoint
+ format version.
+ """
+ if format_version == saver_pb2.SaverDef.V2:
+ return prefix + ".index" # The index file identifies a checkpoint.
+ return prefix # Just the data file.
+
+
+@tf_export("train.latest_checkpoint")
+def latest_checkpoint(checkpoint_dir, latest_filename=None):
+ """Finds the filename of latest saved checkpoint file.
+
+ Args:
+ checkpoint_dir: Directory where the variables were saved.
+ latest_filename: Optional name for the protocol buffer file that
+ contains the list of most recent checkpoint filenames.
+ See the corresponding argument to `Saver.save()`.
+
+ Returns:
+ The full path to the latest checkpoint or `None` if no checkpoint was found.
+ """
+ # Pick the latest checkpoint based on checkpoint state.
+ ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
+ if ckpt and ckpt.model_checkpoint_path:
+ # Look for either a V2 path or a V1 path, with priority for V2.
+ v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
+ saver_pb2.SaverDef.V2)
+ v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
+ saver_pb2.SaverDef.V1)
+ if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
+ v1_path):
+ return ckpt.model_checkpoint_path
+ else:
+ logging.error("Couldn't match files for checkpoint %s",
+ ckpt.model_checkpoint_path)
+ return None
+
+
+@tf_export("train.checkpoint_exists")
+def checkpoint_exists(checkpoint_prefix):
+ """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
+
+ This is the recommended way to check if a checkpoint exists, since it takes
+ into account the naming difference between V1 and V2 formats.
+
+ Args:
+ checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ Returns:
+ A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
+ """
+ pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
+ saver_pb2.SaverDef.V2)
+ if file_io.get_matching_files(pathname):
+ return True
+ elif file_io.get_matching_files(checkpoint_prefix):
+ return True
+ else:
+ return False
+
+
+@tf_export("train.get_checkpoint_mtimes")
+def get_checkpoint_mtimes(checkpoint_prefixes):
+ """Returns the mtimes (modification timestamps) of the checkpoints.
+
+ Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
+ exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
+ that priority.
+
+ This is the recommended way to get the mtimes, since it takes into account
+ the naming difference between V1 and V2 formats.
+
+ Args:
+ checkpoint_prefixes: a list of checkpoint paths, typically the results of
+ `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
+ sharded/non-sharded or V1/V2.
+ Returns:
+ A list of mtimes (in seconds, as floats) of the found checkpoints.
+ """
+ mtimes = []
+
+ def match_maybe_append(pathname):
+ fnames = file_io.get_matching_files(pathname)
+ if fnames:
+ mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
+ return True
+ return False
+
+ for checkpoint_prefix in checkpoint_prefixes:
+ # Tries V2's metadata file first.
+ pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
+ saver_pb2.SaverDef.V2)
+ if match_maybe_append(pathname):
+ continue
+ # Otherwise, tries V1, where the prefix is the complete pathname.
+ match_maybe_append(checkpoint_prefix)
+
+ return mtimes
+
+
+@tf_export("train.remove_checkpoint")
+def remove_checkpoint(checkpoint_prefix,
+ checkpoint_format_version=saver_pb2.SaverDef.V2,
+ meta_graph_suffix="meta"):
+ """Removes a checkpoint given by `checkpoint_prefix`.
+
+ Args:
+ checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
+ of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
+ sharded/non-sharded or V1/V2.
+ checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
+ `SaverDef.V2`.
+ meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
+ """
+ _delete_file_if_exists(
+ meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
+ if checkpoint_format_version == saver_pb2.SaverDef.V2:
+ # V2 has a metadata file and some data files.
+ _delete_file_if_exists(checkpoint_prefix + ".index")
+ _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
+ else:
+ # V1, Legacy. Exact match on the data file.
+ _delete_file_if_exists(checkpoint_prefix)
+
+
+def _delete_file_if_exists(filespec):
+ """Deletes files matching `filespec`."""
+ for pathname in file_io.get_matching_files(filespec):
+ file_io.delete_file(pathname)
+
+
+def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
+ """Returns the meta graph filename.
+
+ Args:
+ checkpoint_filename: Name of the checkpoint file.
+ meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
+
+ Returns:
+ MetaGraph file name.
+ """
+ # If the checkpoint_filename is sharded, the checkpoint_filename could
+ # be of format model.ckpt-step#-?????-of-shard#. For example,
+ # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
+ basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
+ suffixed_filename = ".".join([basename, meta_graph_suffix])
+ return suffixed_filename
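
# A minimal sketch (illustrative, not part of the change above) of the public
# entry points this new module backs. The tf_export decorators keep the
# existing tf.train.* names, so external callers are unchanged, while internal
# callers (e.g. freeze_graph earlier in this diff) import checkpoint_management
# directly. The directory below is hypothetical.
import tensorflow as tf

_ckpt_dir = "/tmp/train"                              # hypothetical path
_prefix = tf.train.latest_checkpoint(_ckpt_dir)       # reads the 'checkpoint' proto
if _prefix and tf.train.checkpoint_exists(_prefix):   # matches V2 '.index' or V1 data
  _mtimes = tf.train.get_checkpoint_mtimes([_prefix])
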
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
new file mode 100644
index 0000000000..4b31d0c613
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -0,0 +1,316 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for tensorflow.python.training.saver.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import os
+import shutil
+import tempfile
+
+from google.protobuf import text_format
+
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_module
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+
+
+class LatestCheckpointWithRelativePaths(test.TestCase):
+
+ @staticmethod
+ @contextlib.contextmanager
+ def tempWorkingDir(temppath):
+ cwd = os.getcwd()
+ os.chdir(temppath)
+ try:
+ yield
+ finally:
+ os.chdir(cwd)
+
+ @staticmethod
+ @contextlib.contextmanager
+ def tempDir():
+ tempdir = tempfile.mkdtemp()
+ try:
+ yield tempdir
+ finally:
+ shutil.rmtree(tempdir)
+
+ def testNameCollision(self):
+ # Make sure we have a clean directory to work in.
+ with self.tempDir() as tempdir:
+ # Jump to that directory until this test is done.
+ with self.tempWorkingDir(tempdir):
+ # Save training snapshots to a relative path.
+ traindir = "train/"
+ os.mkdir(traindir)
+ # Collides with the default name of the checkpoint state file.
+ filepath = os.path.join(traindir, "checkpoint")
+
+ with self.test_session() as sess:
+ unused_a = variables.Variable(0.0) # So that Saver saves something.
+ variables.global_variables_initializer().run()
+
+ # Should fail.
+ saver = saver_module.Saver(sharded=False)
+ with self.assertRaisesRegexp(ValueError, "collides with"):
+ saver.save(sess, filepath)
+
+ # Succeeds: the file will be named "checkpoint-<step>".
+ saver.save(sess, filepath, global_step=1)
+ self.assertIsNotNone(
+ checkpoint_management.latest_checkpoint(traindir))
+
+ # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
+ saver = saver_module.Saver(sharded=True)
+ saver.save(sess, filepath)
+ self.assertIsNotNone(
+ checkpoint_management.latest_checkpoint(traindir))
+
+ # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
+ saver = saver_module.Saver(sharded=True)
+ saver.save(sess, filepath, global_step=1)
+ self.assertIsNotNone(
+ checkpoint_management.latest_checkpoint(traindir))
+
+ def testRelativePath(self):
+ # Make sure we have a clean directory to work in.
+ with self.tempDir() as tempdir:
+
+ # Jump to that directory until this test is done.
+ with self.tempWorkingDir(tempdir):
+
+ # Save training snapshots to a relative path.
+ traindir = "train/"
+ os.mkdir(traindir)
+
+ filename = "snapshot"
+ filepath = os.path.join(traindir, filename)
+
+ with self.test_session() as sess:
+ # Build a simple graph.
+ v0 = variables.Variable(0.0)
+ inc = v0.assign_add(1.0)
+
+ save = saver_module.Saver({"v0": v0})
+
+ # Record a short training history.
+ variables.global_variables_initializer().run()
+ save.save(sess, filepath, global_step=0)
+ inc.eval()
+ save.save(sess, filepath, global_step=1)
+ inc.eval()
+ save.save(sess, filepath, global_step=2)
+
+ with self.test_session() as sess:
+ # Build a new graph with different initialization.
+ v0 = variables.Variable(-1.0)
+
+ # Create a new saver.
+ save = saver_module.Saver({"v0": v0})
+ variables.global_variables_initializer().run()
+
+ # Get the most recent checkpoint name from the training history file.
+ name = checkpoint_management.latest_checkpoint(traindir)
+ self.assertIsNotNone(name)
+
+ # Restore "v0" from that checkpoint.
+ save.restore(sess, name)
+ self.assertEqual(v0.eval(), 2.0)
+
+
+class CheckpointStateTest(test.TestCase):
+
+ def _get_test_dir(self, dirname):
+ test_dir = os.path.join(self.get_temp_dir(), dirname)
+ gfile.MakeDirs(test_dir)
+ return test_dir
+
+ def testAbsPath(self):
+ save_dir = self._get_test_dir("abs_paths")
+ abs_path = os.path.join(save_dir, "model-0")
+ ckpt = checkpoint_management.generate_checkpoint_state_proto(
+ save_dir, abs_path)
+ self.assertEqual(ckpt.model_checkpoint_path, abs_path)
+ self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
+
+ def testRelPath(self):
+ train_dir = "train"
+ model = os.path.join(train_dir, "model-0")
+ # model_checkpoint_path should have no "train" directory part.
+ new_rel_path = "model-0"
+ ckpt = checkpoint_management.generate_checkpoint_state_proto(
+ train_dir, model)
+ self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
+
+ def testAllModelCheckpointPaths(self):
+ save_dir = self._get_test_dir("all_models_test")
+ abs_path = os.path.join(save_dir, "model-0")
+ for paths in [None, [], ["model-2"]]:
+ ckpt = checkpoint_management.generate_checkpoint_state_proto(
+ save_dir, abs_path, all_model_checkpoint_paths=paths)
+ self.assertEqual(ckpt.model_checkpoint_path, abs_path)
+ self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
+ self.assertEqual(
+ len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
+
+ def testUpdateCheckpointState(self):
+ save_dir = self._get_test_dir("update_checkpoint_state")
+ os.chdir(save_dir)
+ # Make a temporary train directory.
+ train_dir = "train"
+ os.mkdir(train_dir)
+ abs_path = os.path.join(save_dir, "model-0")
+ rel_path = os.path.join("train", "model-2")
+ checkpoint_management.update_checkpoint_state(
+ train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
+ ckpt = checkpoint_management.get_checkpoint_state(train_dir)
+ self.assertEqual(ckpt.model_checkpoint_path, rel_path)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
+
+ def testUpdateCheckpointStateSaveRelativePaths(self):
+ save_dir = self._get_test_dir("update_checkpoint_state")
+ os.chdir(save_dir)
+ abs_path2 = os.path.join(save_dir, "model-2")
+ rel_path2 = "model-2"
+ abs_path0 = os.path.join(save_dir, "model-0")
+ rel_path0 = "model-0"
+ checkpoint_management.update_checkpoint_state_internal(
+ save_dir=save_dir,
+ model_checkpoint_path=abs_path2,
+ all_model_checkpoint_paths=[rel_path0, abs_path2],
+ save_relative_paths=True)
+
+ # File should contain relative paths.
+ file_content = file_io.read_file_to_string(
+ os.path.join(save_dir, "checkpoint"))
+ ckpt = CheckpointState()
+ text_format.Merge(file_content, ckpt)
+ self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
+
+ # get_checkpoint_state should return absolute paths.
+ ckpt = checkpoint_management.get_checkpoint_state(save_dir)
+ self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
+ self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
+
+ def testCheckPointStateFailsWhenIncomplete(self):
+ save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
+ os.chdir(save_dir)
+ ckpt_path = os.path.join(save_dir, "checkpoint")
+ ckpt_file = open(ckpt_path, "w")
+ ckpt_file.write("")
+ ckpt_file.close()
+ with self.assertRaises(ValueError):
+ checkpoint_management.get_checkpoint_state(save_dir)
+
+ def testCheckPointCompletesRelativePaths(self):
+ save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
+ os.chdir(save_dir)
+ ckpt_path = os.path.join(save_dir, "checkpoint")
+ ckpt_file = open(ckpt_path, "w")
+ ckpt_file.write("""
+ model_checkpoint_path: "./model.ckpt-687529"
+ all_model_checkpoint_paths: "./model.ckpt-687500"
+ all_model_checkpoint_paths: "./model.ckpt-687529"
+ """)
+ ckpt_file.close()
+ ckpt = checkpoint_management.get_checkpoint_state(save_dir)
+ self.assertEqual(ckpt.model_checkpoint_path,
+ os.path.join(save_dir, "./model.ckpt-687529"))
+ self.assertEqual(ckpt.all_model_checkpoint_paths[0],
+ os.path.join(save_dir, "./model.ckpt-687500"))
+ self.assertEqual(ckpt.all_model_checkpoint_paths[1],
+ os.path.join(save_dir, "./model.ckpt-687529"))
+
+
+class SaverUtilsTest(test.TestCase):
+
+ def setUp(self):
+ self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
+ gfile.MakeDirs(self._base_dir)
+
+ def tearDown(self):
+ gfile.DeleteRecursively(self._base_dir)
+
+ def testCheckpointExists(self):
+ for sharded in (False, True):
+ for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ unused_v = variables.Variable(1.0, name="v")
+ variables.global_variables_initializer().run()
+ saver = saver_module.Saver(sharded=sharded, write_version=version)
+
+ path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
+ self.assertFalse(
+ checkpoint_management.checkpoint_exists(path)) # Not saved yet.
+
+ ckpt_prefix = saver.save(sess, path)
+ self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
+
+ ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
+ self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
+
+ def testGetCheckpointMtimes(self):
+ prefixes = []
+ for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ unused_v = variables.Variable(1.0, name="v")
+ variables.global_variables_initializer().run()
+ saver = saver_module.Saver(write_version=version)
+ prefixes.append(
+ saver.save(sess, os.path.join(self._base_dir, str(version))))
+
+ mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes)
+ self.assertEqual(2, len(mtimes))
+ self.assertTrue(mtimes[1] >= mtimes[0])
+
+ def testRemoveCheckpoint(self):
+ for sharded in (False, True):
+ for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ unused_v = variables.Variable(1.0, name="v")
+ variables.global_variables_initializer().run()
+ saver = saver_module.Saver(sharded=sharded, write_version=version)
+
+ path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
+ ckpt_prefix = saver.save(sess, path)
+ self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
+ checkpoint_management.remove_checkpoint(ckpt_prefix, version)
+ self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
+
+
+if __name__ == "__main__":
+ test.main()
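One detail worth calling out, since it is also exercised by SaveRestoreShardedTest later in this patch: `meta_graph_filename()` strips a trailing shard suffix before appending the meta suffix, so a sharded save path and its unsharded prefix map to the same `.meta` file. A standalone sketch of that regex, using only the standard library (the helper name `to_meta` is illustrative):

    import re

    def to_meta(checkpoint_filename, meta_graph_suffix="meta"):
        # Mirrors checkpoint_management.meta_graph_filename(): drop a trailing
        # "-<i>-of-<n>" (or "-?????-of-<n>") shard suffix, then append ".meta".
        basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
        return ".".join([basename, meta_graph_suffix])

    print(to_meta("model.ckpt-123456-?????-of-00005"))  # model.ckpt-123456.meta
    print(to_meta("model.ckpt-123456"))                 # model.ckpt-123456.meta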
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index a052081630..9b72b09f08 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -277,7 +278,7 @@ def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map):
def _get_checkpoint_filename(ckpt_dir_or_file):
"""Returns checkpoint filename given directory or specific checkpoint file."""
if gfile.IsDirectory(ckpt_dir_or_file):
- return saver.latest_checkpoint(ckpt_dir_or_file)
+ return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
return ckpt_dir_or_file
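The change above only swaps which module supplies `latest_checkpoint`; `_get_checkpoint_filename` still resolves a directory to its newest checkpoint and passes an explicit prefix through unchanged. A small sketch of that behavior from the public API side, assuming a checkpoint already exists under the hypothetical /tmp/train directory:

    import tensorflow as tf

    # A directory is resolved via its "checkpoint" state file ...
    print(tf.train.list_variables("/tmp/train"))
    # ... while an explicit checkpoint prefix is used as-is.
    print(tf.train.list_variables("/tmp/train/model.ckpt-123456"))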
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 35007653a0..8a289b31b5 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -124,14 +124,18 @@ py_test(
],
deps = [
":base",
+ ":tracking",
":util",
+ "//tensorflow/python:checkpoint_management",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:saver",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:template",
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 3c1a4a6f83..5506e6bc4e 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base
@@ -467,7 +468,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
- root.restore(saver_lib.latest_checkpoint(checkpoint_directory))
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
for _ in range(num_training_steps):
# TODO(allenl): Use a Dataset and serialize/checkpoint it.
input_value = constant_op.constant([[3.]])
@@ -495,7 +497,8 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
with self.test_session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
@@ -528,7 +531,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
@@ -561,7 +565,8 @@ class CheckpointingTests(test.TestCase):
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
@function.defun
@@ -1180,7 +1185,8 @@ class CheckpointingTests(test.TestCase):
optimizer_checkpoint = checkpointable_utils.Checkpoint(
optimizer=optimizer)
- checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 3806056f01..92533ca4f3 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@@ -1364,8 +1365,8 @@ class MonitoredSessionTest(test.TestCase):
with monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
scaffold,
- checkpoint_filename_with_path=saver_lib.latest_checkpoint(
- logdir))) as session:
+ checkpoint_filename_with_path=
+ checkpoint_management.latest_checkpoint(logdir))) as session:
self.assertEqual(2, session.run(gstep))
def test_retry_initialization_on_aborted_error(self):
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index c80cdf03be..213c11c50d 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -21,15 +21,12 @@ from __future__ import print_function
import collections
import os.path
-import re
import time
import uuid
import numpy as np
import six
-from google.protobuf import text_format
-
from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
@@ -41,7 +38,6 @@ from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
-from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
@@ -52,14 +48,25 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saveable_object
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
+# TODO(allenl): Remove these aliases once all users are migrated off.
+get_checkpoint_state = checkpoint_management.get_checkpoint_state
+update_checkpoint_state = checkpoint_management.update_checkpoint_state
+generate_checkpoint_state_proto = (
+ checkpoint_management.generate_checkpoint_state_proto)
+latest_checkpoint = checkpoint_management.latest_checkpoint
+checkpoint_exists = checkpoint_management.checkpoint_exists
+get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
+remove_checkpoint = checkpoint_management.remove_checkpoint
+
+
# Op names which identify variable reads which should be saved.
_VARIABLE_OPS = set(["Variable",
"VariableV2",
@@ -858,218 +865,6 @@ def _get_saver_or_default():
return saver
-def _GetCheckpointFilename(save_dir, latest_filename):
- """Returns a filename for storing the CheckpointState.
-
- Args:
- save_dir: The directory for saving and restoring checkpoints.
- latest_filename: Name of the file in 'save_dir' that is used
- to store the CheckpointState.
-
- Returns:
- The path of the file that contains the CheckpointState proto.
- """
- if latest_filename is None:
- latest_filename = "checkpoint"
- return os.path.join(save_dir, latest_filename)
-
-
-@tf_export("train.generate_checkpoint_state_proto")
-def generate_checkpoint_state_proto(save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=None):
- """Generates a checkpoint state proto.
-
- Args:
- save_dir: Directory where the model was saved.
- model_checkpoint_path: The checkpoint file.
- all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
- checkpoints, sorted from oldest to newest. If this is a non-empty list,
- the last element must be equal to model_checkpoint_path. These paths
- are also saved in the CheckpointState proto.
-
- Returns:
- CheckpointState proto with model_checkpoint_path and
- all_model_checkpoint_paths updated to either absolute paths or
- relative paths to the current save_dir.
- """
- if all_model_checkpoint_paths is None:
- all_model_checkpoint_paths = []
-
- if (not all_model_checkpoint_paths or
- all_model_checkpoint_paths[-1] != model_checkpoint_path):
- logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
- model_checkpoint_path)
- all_model_checkpoint_paths.append(model_checkpoint_path)
-
- # Relative paths need to be rewritten to be relative to the "save_dir"
- # if model_checkpoint_path already contains "save_dir".
- if not os.path.isabs(save_dir):
- if not os.path.isabs(model_checkpoint_path):
- model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
- for i in range(len(all_model_checkpoint_paths)):
- p = all_model_checkpoint_paths[i]
- if not os.path.isabs(p):
- all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
-
- coord_checkpoint_proto = CheckpointState(
- model_checkpoint_path=model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
-
- return coord_checkpoint_proto
-
-
-@tf_export("train.update_checkpoint_state")
-def update_checkpoint_state(save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=None,
- latest_filename=None):
- """Updates the content of the 'checkpoint' file.
-
- This updates the checkpoint file containing a CheckpointState
- proto.
-
- Args:
- save_dir: Directory where the model was saved.
- model_checkpoint_path: The checkpoint file.
- all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
- checkpoints, sorted from oldest to newest. If this is a non-empty list,
- the last element must be equal to model_checkpoint_path. These paths
- are also saved in the CheckpointState proto.
- latest_filename: Optional name of the checkpoint file. Default to
- 'checkpoint'.
-
- Raises:
- RuntimeError: If any of the model checkpoint paths conflict with the file
- containing CheckpointSate.
- """
- _update_checkpoint_state(
- save_dir=save_dir,
- model_checkpoint_path=model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths,
- latest_filename=latest_filename,
- save_relative_paths=False)
-
-
-def _update_checkpoint_state(save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=None,
- latest_filename=None,
- save_relative_paths=False):
- """Updates the content of the 'checkpoint' file.
-
- This updates the checkpoint file containing a CheckpointState
- proto.
-
- Args:
- save_dir: Directory where the model was saved.
- model_checkpoint_path: The checkpoint file.
- all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
- checkpoints, sorted from oldest to newest. If this is a non-empty list,
- the last element must be equal to model_checkpoint_path. These paths
- are also saved in the CheckpointState proto.
- latest_filename: Optional name of the checkpoint file. Default to
- 'checkpoint'.
- save_relative_paths: If `True`, will write relative paths to the checkpoint
- state file.
-
- Raises:
- RuntimeError: If any of the model checkpoint paths conflict with the file
- containing CheckpointSate.
- """
- # Writes the "checkpoint" file for the coordinator for later restoration.
- coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
- if save_relative_paths:
- if os.path.isabs(model_checkpoint_path):
- rel_model_checkpoint_path = os.path.relpath(
- model_checkpoint_path, save_dir)
- else:
- rel_model_checkpoint_path = model_checkpoint_path
- rel_all_model_checkpoint_paths = []
- for p in all_model_checkpoint_paths:
- if os.path.isabs(p):
- rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
- else:
- rel_all_model_checkpoint_paths.append(p)
- ckpt = generate_checkpoint_state_proto(
- save_dir,
- rel_model_checkpoint_path,
- all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
- else:
- ckpt = generate_checkpoint_state_proto(
- save_dir,
- model_checkpoint_path,
- all_model_checkpoint_paths=all_model_checkpoint_paths)
-
- if coord_checkpoint_filename == ckpt.model_checkpoint_path:
- raise RuntimeError("Save path '%s' conflicts with path used for "
- "checkpoint state. Please use a different save path." %
- model_checkpoint_path)
-
- # Preventing potential read/write race condition by *atomically* writing to a
- # file.
- file_io.atomic_write_string_to_file(coord_checkpoint_filename,
- text_format.MessageToString(ckpt))
-
-
-@tf_export("train.get_checkpoint_state")
-def get_checkpoint_state(checkpoint_dir, latest_filename=None):
- """Returns CheckpointState proto from the "checkpoint" file.
-
- If the "checkpoint" file contains a valid CheckpointState
- proto, returns it.
-
- Args:
- checkpoint_dir: The directory of checkpoints.
- latest_filename: Optional name of the checkpoint file. Default to
- 'checkpoint'.
-
- Returns:
- A CheckpointState if the state was available, None
- otherwise.
-
- Raises:
- ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
- """
- ckpt = None
- coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
- latest_filename)
- f = None
- try:
- # Check that the file exists before opening it to avoid
- # many lines of errors from colossus in the logs.
- if file_io.file_exists(coord_checkpoint_filename):
- file_content = file_io.read_file_to_string(
- coord_checkpoint_filename)
- ckpt = CheckpointState()
- text_format.Merge(file_content, ckpt)
- if not ckpt.model_checkpoint_path:
- raise ValueError("Invalid checkpoint state loaded from "
- + checkpoint_dir)
- # For relative model_checkpoint_path and all_model_checkpoint_paths,
- # prepend checkpoint_dir.
- if not os.path.isabs(ckpt.model_checkpoint_path):
- ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
- ckpt.model_checkpoint_path)
- for i in range(len(ckpt.all_model_checkpoint_paths)):
- p = ckpt.all_model_checkpoint_paths[i]
- if not os.path.isabs(p):
- ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
- except errors.OpError as e:
- # It's ok if the file cannot be read
- logging.warning("%s: %s", type(e).__name__, e)
- logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
- return None
- except text_format.ParseError as e:
- logging.warning("%s: %s", type(e).__name__, e)
- logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
- return None
- finally:
- if f:
- f.close()
- return ckpt
-
-
@tf_export("train.Saver")
class Saver(object):
"""Saves and restores variables.
@@ -1412,7 +1207,7 @@ class Saver(object):
# Otherwise delete the files.
try:
- remove_checkpoint(
+ checkpoint_management.remove_checkpoint(
self._CheckpointFilename(p), self.saver_def.version,
meta_graph_suffix)
except Exception as e: # pylint: disable=broad-except
@@ -1518,7 +1313,7 @@ class Saver(object):
Args:
checkpoint_paths: a list of checkpoint paths.
"""
- mtimes = get_checkpoint_mtimes(checkpoint_paths)
+ mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
def save(self,
@@ -1624,7 +1419,7 @@ class Saver(object):
model_checkpoint_path = compat.as_str(model_checkpoint_path)
if write_state:
self._RecordLastCheckpoint(model_checkpoint_path)
- _update_checkpoint_state(
+ checkpoint_management.update_checkpoint_state_internal(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
@@ -1639,7 +1434,7 @@ class Saver(object):
raise exc
if write_meta_graph:
- meta_graph_filename = _meta_graph_filename(
+ meta_graph_filename = checkpoint_management.meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if not context.executing_eagerly():
with sess.graph.as_default():
@@ -1714,7 +1509,7 @@ class Saver(object):
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
- if not checkpoint_exists(compat.as_text(save_path)):
+ if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: "
+ compat.as_text(save_path))
@@ -1800,55 +1595,6 @@ class Saver(object):
export_scope=export_scope)
-def _prefix_to_checkpoint_path(prefix, format_version):
- """Returns the pathname of a checkpoint file, given the checkpoint prefix.
-
- For V1 checkpoint, simply returns the prefix itself (the data file). For V2,
- returns the pathname to the index file.
-
- Args:
- prefix: a string, the prefix of a checkpoint.
- format_version: the checkpoint format version that corresponds to the
- prefix.
- Returns:
- The pathname of a checkpoint file, taking into account the checkpoint
- format version.
- """
- if format_version == saver_pb2.SaverDef.V2:
- return prefix + ".index" # The index file identifies a checkpoint.
- return prefix # Just the data file.
-
-
-@tf_export("train.latest_checkpoint")
-def latest_checkpoint(checkpoint_dir, latest_filename=None):
- """Finds the filename of latest saved checkpoint file.
-
- Args:
- checkpoint_dir: Directory where the variables were saved.
- latest_filename: Optional name for the protocol buffer file that
- contains the list of most recent checkpoint filenames.
- See the corresponding argument to `Saver.save()`.
-
- Returns:
- The full path to the latest checkpoint or `None` if no checkpoint was found.
- """
- # Pick the latest checkpoint based on checkpoint state.
- ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
- if ckpt and ckpt.model_checkpoint_path:
- # Look for either a V2 path or a V1 path, with priority for V2.
- v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
- saver_pb2.SaverDef.V2)
- v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
- saver_pb2.SaverDef.V1)
- if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
- v1_path):
- return ckpt.model_checkpoint_path
- else:
- logging.error("Couldn't match files for checkpoint %s",
- ckpt.model_checkpoint_path)
- return None
-
-
@tf_export("train.import_meta_graph")
def import_meta_graph(meta_graph_or_file, clear_devices=False,
import_scope=None, **kwargs):
@@ -2056,119 +1802,6 @@ def export_meta_graph(filename=None,
return meta_graph_def
-@tf_export("train.checkpoint_exists")
-def checkpoint_exists(checkpoint_prefix):
- """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
-
- This is the recommended way to check if a checkpoint exists, since it takes
- into account the naming difference between V1 and V2 formats.
-
- Args:
- checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
- priority. Typically the result of `Saver.save()` or that of
- `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
- V1/V2.
- Returns:
- A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
- """
- pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
- saver_pb2.SaverDef.V2)
- if file_io.get_matching_files(pathname):
- return True
- elif file_io.get_matching_files(checkpoint_prefix):
- return True
- else:
- return False
-
-
-@tf_export("train.get_checkpoint_mtimes")
-def get_checkpoint_mtimes(checkpoint_prefixes):
- """Returns the mtimes (modification timestamps) of the checkpoints.
-
- Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files
- exist, collect their mtime. Both V2 and V1 checkpoints are considered, in
- that priority.
-
- This is the recommended way to get the mtimes, since it takes into account
- the naming difference between V1 and V2 formats.
-
- Args:
- checkpoint_prefixes: a list of checkpoint paths, typically the results of
- `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
- sharded/non-sharded or V1/V2.
- Returns:
- A list of mtimes (in microseconds) of the found checkpoints.
- """
- mtimes = []
-
- def match_maybe_append(pathname):
- fnames = file_io.get_matching_files(pathname)
- if fnames:
- mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
- return True
- return False
-
- for checkpoint_prefix in checkpoint_prefixes:
- # Tries V2's metadata file first.
- pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
- saver_pb2.SaverDef.V2)
- if match_maybe_append(pathname):
- continue
- # Otherwise, tries V1, where the prefix is the complete pathname.
- match_maybe_append(checkpoint_prefix)
-
- return mtimes
-
-
-@tf_export("train.remove_checkpoint")
-def remove_checkpoint(checkpoint_prefix,
- checkpoint_format_version=saver_pb2.SaverDef.V2,
- meta_graph_suffix="meta"):
- """Removes a checkpoint given by `checkpoint_prefix`.
-
- Args:
- checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
- of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
- sharded/non-sharded or V1/V2.
- checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
- `SaverDef.V2`.
- meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
- """
- _delete_file_if_exists(
- _meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
- if checkpoint_format_version == saver_pb2.SaverDef.V2:
- # V2 has a metadata file and some data files.
- _delete_file_if_exists(checkpoint_prefix + ".index")
- _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
- else:
- # V1, Legacy. Exact match on the data file.
- _delete_file_if_exists(checkpoint_prefix)
-
-
-def _delete_file_if_exists(filespec):
- """Deletes files matching `filespec`."""
- for pathname in file_io.get_matching_files(filespec):
- file_io.delete_file(pathname)
-
-
-def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
- """Returns the meta graph filename.
-
- Args:
- checkpoint_filename: Name of the checkpoint file.
- meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
-
- Returns:
- MetaGraph file name.
- """
- # If the checkpoint_filename is sharded, the checkpoint_filename could
- # be of format model.ckpt-step#-?????-of-shard#. For example,
- # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
- basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
- meta_graph_filename = ".".join([basename, meta_graph_suffix])
- return meta_graph_filename
-
-
def _wrap_restore_error_with_msg(err, extra_verbiage):
err_msg = ("Restoring from checkpoint failed. This is most likely "
"due to {} from the checkpoint. Please ensure that you "
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 204e81dda0..941aafc780 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -18,20 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import contextlib
import functools
import math
import os
import random
-import shutil
-import tempfile
import time
import numpy as np
import six
from google.protobuf.any_pb2 import Any
-from google.protobuf import text_format
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
@@ -71,12 +67,12 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.training.checkpointable import base as checkpointable_base
from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -343,11 +339,13 @@ class SaverTest(test.TestCase):
self.assertTrue(isinstance(val, six.string_types))
self.assertEqual(save_path1, val)
- self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1)
+ self.assertEqual(
+ checkpoint_management.latest_checkpoint(save_dir1), save_path1)
save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
os.renames(save_dir1, save_dir2)
save_path2 = os.path.join(save_dir2, "save_copy_restore")
- self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2)
+ self.assertEqual(
+ checkpoint_management.latest_checkpoint(save_dir2), save_path2)
# Start a second session. In that session the parameter nodes
# have not been initialized either.
@@ -857,7 +855,7 @@ class SaveRestoreShardedTest(test.TestCase):
self.assertEqual(save_path + "-?????-of-00002", val)
else:
self.assertEqual(save_path, val)
- meta_graph_filename = saver_module._meta_graph_filename(val)
+ meta_graph_filename = checkpoint_management.meta_graph_filename(val)
self.assertEqual(save_path + ".meta", meta_graph_filename)
if save._write_version is saver_pb2.SaverDef.V1:
@@ -951,11 +949,11 @@ class SaveRestoreShardedTest(test.TestCase):
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(
- saver_module.latest_checkpoint(self.get_temp_dir()),
+ checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
else:
self.assertEqual(
- saver_module.latest_checkpoint(self.get_temp_dir()),
+ checkpoint_management.latest_checkpoint(self.get_temp_dir()),
os.path.join(self.get_temp_dir(), "sharded_basics"))
def testSaverDef(self):
@@ -1105,7 +1103,7 @@ class MaxToKeepTest(test.TestCase):
def assertCheckpointState(self, model_checkpoint_path,
all_model_checkpoint_paths, save_dir):
- checkpoint_state = saver_module.get_checkpoint_state(save_dir)
+ checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
self.assertEqual(checkpoint_state.model_checkpoint_path,
model_checkpoint_path)
self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
@@ -1113,7 +1111,7 @@ class MaxToKeepTest(test.TestCase):
def testMaxToKeepEager(self):
with context.eager_mode():
- save_dir = self._get_test_dir("max_to_keep_non_sharded")
+ save_dir = self._get_test_dir("max_to_keep_eager")
v = variable_scope.variable(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
@@ -1123,7 +1121,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
@@ -1131,8 +1129,8 @@ class MaxToKeepTest(test.TestCase):
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
@@ -1140,9 +1138,9 @@ class MaxToKeepTest(test.TestCase):
s3 = save.save(None, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
- self.assertTrue(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
@@ -1157,9 +1155,9 @@ class MaxToKeepTest(test.TestCase):
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s3))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1168,8 +1166,8 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(None, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1178,9 +1176,9 @@ class MaxToKeepTest(test.TestCase):
s2 = save2.save(None, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
# Deleted by the first helper.
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
def testNonSharded(self):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
@@ -1193,7 +1191,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s1], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s1],
@@ -1201,8 +1199,8 @@ class MaxToKeepTest(test.TestCase):
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s1, s2],
@@ -1210,9 +1208,9 @@ class MaxToKeepTest(test.TestCase):
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
- self.assertTrue(saver_module.checkpoint_exists(s2))
- self.assertTrue(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertCheckpointState(
model_checkpoint_path=s3,
all_model_checkpoint_paths=[s2, s3],
@@ -1231,15 +1229,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s2 again (old s2 is removed first, then new s2 appended)
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s1))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s1))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
- self.assertTrue(saver_module.checkpoint_exists(s3))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1248,15 +1249,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1268,16 +1272,19 @@ class MaxToKeepTest(test.TestCase):
s2 = save2.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s3, s2], save2.last_checkpoints)
# Created by the first helper.
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1286,15 +1293,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should now be deleted as oldest in list)
s1 = save2.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save2.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1306,16 +1316,19 @@ class MaxToKeepTest(test.TestCase):
s2 = save3.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s2], save3.last_checkpoints)
# Created by the first helper.
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
# Deleted by the first helper.
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
# Even though the file for s1 exists, this saver isn't aware of it, which
# is why it doesn't end up in the checkpoint state.
self.assertCheckpointState(
@@ -1326,15 +1339,18 @@ class MaxToKeepTest(test.TestCase):
# Adding s1 (s3 should not be deleted because helper is unaware of it)
s1 = save3.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([s2, s1], save3.last_checkpoints)
- self.assertFalse(saver_module.checkpoint_exists(s3))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s3)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s2)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
+ checkpoint_management.checkpoint_exists(
+ checkpoint_management.meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1365,7 +1381,8 @@ class MaxToKeepTest(test.TestCase):
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
@@ -1373,27 +1390,32 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual(2, len(gfile.Glob(s1)))
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertEqual(0, len(gfile.Glob(s1 + "*")))
- self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertFalse(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s3)))
else:
self.assertEqual(4, len(gfile.Glob(s3 + "*")))
- self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3)))
+ self.assertTrue(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s3)))
def testNoMaxToKeep(self):
save_dir = self._get_test_dir("no_max_to_keep")
@@ -1408,20 +1430,20 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([], save.last_checkpoints)
s1 = save.save(sess, os.path.join(save_dir, "s1"))
self.assertEqual([], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([], save.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
# Test max_to_keep being 0.
save2 = saver_module.Saver({"v": v}, max_to_keep=0)
self.assertEqual([], save2.last_checkpoints)
s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
self.assertEqual([], save2.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s1))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
self.assertEqual([], save2.last_checkpoints)
- self.assertTrue(saver_module.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s2))
def testNoMetaGraph(self):
save_dir = self._get_test_dir("no_meta_graph")
@@ -1432,8 +1454,9 @@ class MaxToKeepTest(test.TestCase):
variables.global_variables_initializer().run()
s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertFalse(
+ gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
class KeepCheckpointEveryNHoursTest(test.TestCase):
@@ -1489,10 +1512,10 @@ class KeepCheckpointEveryNHoursTest(test.TestCase):
self.assertEqual([s3, s4], save.last_checkpoints)
# Check that s1 is still here, but s2 is gone.
- self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertFalse(saver_module.checkpoint_exists(s2))
- self.assertTrue(saver_module.checkpoint_exists(s3))
- self.assertTrue(saver_module.checkpoint_exists(s4))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s1))
+ self.assertFalse(checkpoint_management.checkpoint_exists(s2))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s3))
+ self.assertTrue(checkpoint_management.checkpoint_exists(s4))
class SaveRestoreWithVariableNameMap(test.TestCase):
@@ -1571,221 +1594,6 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
self._testNonReshape(variables.Variable)
-class LatestCheckpointWithRelativePaths(test.TestCase):
-
- @staticmethod
- @contextlib.contextmanager
- def tempWorkingDir(temppath):
- cwd = os.getcwd()
- os.chdir(temppath)
- try:
- yield
- finally:
- os.chdir(cwd)
-
- @staticmethod
- @contextlib.contextmanager
- def tempDir():
- tempdir = tempfile.mkdtemp()
- try:
- yield tempdir
- finally:
- shutil.rmtree(tempdir)
-
- def testNameCollision(self):
- # Make sure we have a clean directory to work in.
- with self.tempDir() as tempdir:
- # Jump to that directory until this test is done.
- with self.tempWorkingDir(tempdir):
- # Save training snapshots to a relative path.
- traindir = "train/"
- os.mkdir(traindir)
- # Collides with the default name of the checkpoint state file.
- filepath = os.path.join(traindir, "checkpoint")
-
- with self.test_session() as sess:
- unused_a = variables.Variable(0.0) # So that Saver saves something.
- variables.global_variables_initializer().run()
-
- # Should fail.
- saver = saver_module.Saver(sharded=False)
- with self.assertRaisesRegexp(ValueError, "collides with"):
- saver.save(sess, filepath)
-
- # Succeeds: the file will be named "checkpoint-<step>".
- saver.save(sess, filepath, global_step=1)
- self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
-
- # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
- saver = saver_module.Saver(sharded=True)
- saver.save(sess, filepath)
- self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
-
- # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
- saver = saver_module.Saver(sharded=True)
- saver.save(sess, filepath, global_step=1)
- self.assertIsNotNone(saver_module.latest_checkpoint(traindir))
-
- def testRelativePath(self):
- # Make sure we have a clean directory to work in.
- with self.tempDir() as tempdir:
-
- # Jump to that directory until this test is done.
- with self.tempWorkingDir(tempdir):
-
- # Save training snapshots to a relative path.
- traindir = "train/"
- os.mkdir(traindir)
-
- filename = "snapshot"
- filepath = os.path.join(traindir, filename)
-
- with self.test_session() as sess:
- # Build a simple graph.
- v0 = variables.Variable(0.0)
- inc = v0.assign_add(1.0)
-
- save = saver_module.Saver({"v0": v0})
-
- # Record a short training history.
- variables.global_variables_initializer().run()
- save.save(sess, filepath, global_step=0)
- inc.eval()
- save.save(sess, filepath, global_step=1)
- inc.eval()
- save.save(sess, filepath, global_step=2)
-
- with self.test_session() as sess:
- # Build a new graph with different initialization.
- v0 = variables.Variable(-1.0)
-
- # Create a new saver.
- save = saver_module.Saver({"v0": v0})
- variables.global_variables_initializer().run()
-
- # Get the most recent checkpoint name from the training history file.
- name = saver_module.latest_checkpoint(traindir)
- self.assertIsNotNone(name)
-
- # Restore "v0" from that checkpoint.
- save.restore(sess, name)
- self.assertEqual(v0.eval(), 2.0)
-
-
-class CheckpointStateTest(test.TestCase):
-
- def _get_test_dir(self, dirname):
- test_dir = os.path.join(self.get_temp_dir(), dirname)
- gfile.MakeDirs(test_dir)
- return test_dir
-
- def testAbsPath(self):
- save_dir = self._get_test_dir("abs_paths")
- abs_path = os.path.join(save_dir, "model-0")
- ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
- self.assertEqual(ckpt.model_checkpoint_path, abs_path)
- self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
-
- def testRelPath(self):
- train_dir = "train"
- model = os.path.join(train_dir, "model-0")
- # model_checkpoint_path should have no "train" directory part.
- new_rel_path = "model-0"
- ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model)
- self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)
-
- def testAllModelCheckpointPaths(self):
- save_dir = self._get_test_dir("all_models_test")
- abs_path = os.path.join(save_dir, "model-0")
- for paths in [None, [], ["model-2"]]:
- ckpt = saver_module.generate_checkpoint_state_proto(
- save_dir, abs_path, all_model_checkpoint_paths=paths)
- self.assertEqual(ckpt.model_checkpoint_path, abs_path)
- self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
- self.assertEqual(
- len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)
-
- def testUpdateCheckpointState(self):
- save_dir = self._get_test_dir("update_checkpoint_state")
- os.chdir(save_dir)
- # Make a temporary train directory.
- train_dir = "train"
- os.mkdir(train_dir)
- abs_path = os.path.join(save_dir, "model-0")
- rel_path = os.path.join("train", "model-2")
- saver_module.update_checkpoint_state(
- train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
- ckpt = saver_module.get_checkpoint_state(train_dir)
- self.assertEqual(ckpt.model_checkpoint_path, rel_path)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
- self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)
-
- def testUpdateCheckpointStateSaveRelativePaths(self):
- save_dir = self._get_test_dir("update_checkpoint_state")
- os.chdir(save_dir)
- abs_path2 = os.path.join(save_dir, "model-2")
- rel_path2 = "model-2"
- abs_path0 = os.path.join(save_dir, "model-0")
- rel_path0 = "model-0"
- saver_module._update_checkpoint_state( # pylint: disable=protected-access
- save_dir=save_dir,
- model_checkpoint_path=abs_path2,
- all_model_checkpoint_paths=[rel_path0, abs_path2],
- save_relative_paths=True)
-
- # File should contain relative paths.
- file_content = file_io.read_file_to_string(
- os.path.join(save_dir, "checkpoint"))
- ckpt = CheckpointState()
- text_format.Merge(file_content, ckpt)
- self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
-
- # get_checkpoint_state should return absolute paths.
- ckpt = saver_module.get_checkpoint_state(save_dir)
- self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
- self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
- self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
-
- def testCheckPointStateFailsWhenIncomplete(self):
- save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
- os.chdir(save_dir)
- ckpt_path = os.path.join(save_dir, "checkpoint")
- ckpt_file = open(ckpt_path, "w")
- ckpt_file.write("")
- ckpt_file.close()
- with self.assertRaises(ValueError):
- saver_module.get_checkpoint_state(save_dir)
-
- def testCheckPointCompletesRelativePaths(self):
- save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
- os.chdir(save_dir)
- ckpt_path = os.path.join(save_dir, "checkpoint")
- ckpt_file = open(ckpt_path, "w")
- ckpt_file.write("""
- model_checkpoint_path: "./model.ckpt-687529"
- all_model_checkpoint_paths: "./model.ckpt-687500"
- all_model_checkpoint_paths: "./model.ckpt-687529"
- """)
- ckpt_file.close()
- ckpt = saver_module.get_checkpoint_state(save_dir)
- self.assertEqual(ckpt.model_checkpoint_path,
- os.path.join(save_dir, "./model.ckpt-687529"))
- self.assertEqual(ckpt.all_model_checkpoint_paths[0],
- os.path.join(save_dir, "./model.ckpt-687500"))
- self.assertEqual(ckpt.all_model_checkpoint_paths[1],
- os.path.join(save_dir, "./model.ckpt-687529"))
-
-
class MetaGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
@@ -2628,62 +2436,6 @@ class WriteGraphTest(test.TestCase):
self.assertTrue(os.path.exists(path))
-class SaverUtilsTest(test.TestCase):
-
- def setUp(self):
- self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
- gfile.MakeDirs(self._base_dir)
-
- def tearDown(self):
- gfile.DeleteRecursively(self._base_dir)
-
- def testCheckpointExists(self):
- for sharded in (False, True):
- for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
- unused_v = variables.Variable(1.0, name="v")
- variables.global_variables_initializer().run()
- saver = saver_module.Saver(sharded=sharded, write_version=version)
-
- path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
- self.assertFalse(
- saver_module.checkpoint_exists(path)) # Not saved yet.
-
- ckpt_prefix = saver.save(sess, path)
- self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
-
- ckpt_prefix = saver_module.latest_checkpoint(self._base_dir)
- self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
-
- def testGetCheckpointMtimes(self):
- prefixes = []
- for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
- unused_v = variables.Variable(1.0, name="v")
- variables.global_variables_initializer().run()
- saver = saver_module.Saver(write_version=version)
- prefixes.append(
- saver.save(sess, os.path.join(self._base_dir, str(version))))
-
- mtimes = saver_module.get_checkpoint_mtimes(prefixes)
- self.assertEqual(2, len(mtimes))
- self.assertTrue(mtimes[1] >= mtimes[0])
-
- def testRemoveCheckpoint(self):
- for sharded in (False, True):
- for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
- with self.test_session(graph=ops_lib.Graph()) as sess:
- unused_v = variables.Variable(1.0, name="v")
- variables.global_variables_initializer().run()
- saver = saver_module.Saver(sharded=sharded, write_version=version)
-
- path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
- ckpt_prefix = saver.save(sess, path)
- self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
- saver_module.remove_checkpoint(ckpt_prefix, version)
- self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix))
-
-
class ScopedGraphTest(test.TestCase):
def _get_test_dir(self, dirname):
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index 974f75777f..a2e0645ba8 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -24,7 +24,7 @@ from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver as saver_mod
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.util.tf_export import tf_export
@@ -197,13 +197,13 @@ class SessionManager(object):
# Waits up until max_wait_secs for checkpoint to become available.
wait_time = 0
- ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
+ ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
while not ckpt or not ckpt.model_checkpoint_path:
if wait_for_checkpoint and wait_time < max_wait_secs:
logging.info("Waiting for checkpoint to be available.")
time.sleep(self._recovery_wait_secs)
wait_time += self._recovery_wait_secs
- ckpt = saver_mod.get_checkpoint_state(checkpoint_dir)
+ ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
else:
return sess, False
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index 6670d9365f..d7e6dac95b 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_manager
@@ -174,13 +175,13 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
self._test_recovered_variable(
- checkpoint_filename_with_path=saver_lib.latest_checkpoint(
+ checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
# Cannot set both checkpoint_dir and checkpoint_filename_with_path.
with self.assertRaises(ValueError):
self._test_recovered_variable(
checkpoint_dir=checkpoint_dir,
- checkpoint_filename_with_path=saver_lib.latest_checkpoint(
+ checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
def testWaitForSessionReturnsNoneAfterTimeout(self):
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 4abce85852..71ed88093a 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
@@ -83,7 +84,7 @@ class SupervisorTest(test.TestCase):
end_time = time.time() + timeout_secs
while time.time() < end_time:
if for_checkpoint:
- if saver_lib.checkpoint_exists(pattern):
+ if checkpoint_management.checkpoint_exists(pattern):
return
else:
if len(gfile.Glob(pattern)) >= 1:
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index 3f2dc67976..544010afbe 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -82,12 +82,12 @@ from tensorflow.python.training.monitored_session import WorkerSessionCreator
from tensorflow.python.training.monitored_session import MonitoredSession
from tensorflow.python.training.monitored_session import SingularMonitoredSession
from tensorflow.python.training.saver import Saver
-from tensorflow.python.training.saver import checkpoint_exists
-from tensorflow.python.training.saver import generate_checkpoint_state_proto
-from tensorflow.python.training.saver import get_checkpoint_mtimes
-from tensorflow.python.training.saver import get_checkpoint_state
-from tensorflow.python.training.saver import latest_checkpoint
-from tensorflow.python.training.saver import update_checkpoint_state
+from tensorflow.python.training.checkpoint_management import checkpoint_exists
+from tensorflow.python.training.checkpoint_management import generate_checkpoint_state_proto
+from tensorflow.python.training.checkpoint_management import get_checkpoint_mtimes
+from tensorflow.python.training.checkpoint_management import get_checkpoint_state
+from tensorflow.python.training.checkpoint_management import latest_checkpoint
+from tensorflow.python.training.checkpoint_management import update_checkpoint_state
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.training.saver import import_meta_graph
from tensorflow.python.training.session_run_hook import SessionRunHook
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 0877b2a8a2..2ff3eeb153 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -44,11 +44,13 @@ def global_step(sess, global_step_tensor):
"""Small helper to get the global step.
```python
- # Creates a variable to hold the global_step.
+ # Create a variable to hold the global_step.
global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
- # Creates a session.
+ # Create a session.
sess = tf.Session()
- # Initializes the variable.
+  # Initialize the variable.
+ sess.run(global_step_tensor.initializer)
+ # Get the variable value.
print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))
global_step: 10
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 9e2202eaf8..74e1fb227f 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -388,7 +388,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples,
Args:
names_to_ok_vals: dict from string arg_name to a list of values,
possibly empty, which should not elicit a warning.
- arg_spec: Output from tf_inspect.getargspec on the called function.
+ arg_spec: Output from tf_inspect.getfullargspec on the called function.
Returns:
Dictionary from arg_name to DeprecatedArgSpec.
@@ -408,16 +408,16 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples,
decorator_utils.validate_callable(func, 'deprecated_args')
deprecated_arg_names = _get_arg_names_to_ok_vals()
- arg_spec = tf_inspect.getargspec(func)
+ arg_spec = tf_inspect.getfullargspec(func)
deprecated_positions = _get_deprecated_positional_arguments(
deprecated_arg_names, arg_spec)
is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names
- is_kwargs_deprecated = arg_spec.keywords in deprecated_arg_names
+ is_kwargs_deprecated = arg_spec.varkw in deprecated_arg_names
if (len(deprecated_positions) + is_varargs_deprecated + is_kwargs_deprecated
!= len(deprecated_arg_names_or_tuples)):
- known_args = arg_spec.args + [arg_spec.varargs, arg_spec.keywords]
+ known_args = arg_spec.args + [arg_spec.varargs, arg_spec.varkw]
missing_args = [arg_name for arg_name in deprecated_arg_names
if arg_name not in known_args]
raise ValueError('The following deprecated arguments are not present '
@@ -467,7 +467,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples,
if is_varargs_deprecated and len(args) > len(arg_spec.args):
invalid_args.append(arg_spec.varargs)
if is_kwargs_deprecated and kwargs:
- invalid_args.append(arg_spec.keywords)
+ invalid_args.append(arg_spec.varkw)
for arg_name in deprecated_arg_names:
if (arg_name in kwargs and
not (deprecated_positions[arg_name].has_ok_value and
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index fd75c6885a..2369eb610e 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -354,6 +354,10 @@ class NestTest(parameterized.TestCase, test.TestCase):
EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
+ def testHeterogeneousComparison(self):
+ nest.assert_same_structure({"a": 4}, _CustomMapping(a=3))
+ nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
+
@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index ec20998bdd..778121e15b 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -184,7 +184,7 @@ else:
Returns:
A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
"""
- argspecs = _inspect.getargspec(target)
+ argspecs = getargspec(target)
fullargspecs = FullArgSpec(
args=argspecs.args,
varargs=argspecs.varargs,
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
index 2f6021c7d8..d3b7e4b969 100644
--- a/tensorflow/python/util/tf_inspect_test.py
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -122,6 +122,18 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
+ def testGetFullArgsSpecForPartial(self):
+
+ def func(a, b):
+ del a, b
+
+ partial_function = functools.partial(func, 1)
+ argspec = tf_inspect.FullArgSpec(
+ args=['b'], varargs=None, varkw=None, defaults=None,
+ kwonlyargs=[], kwonlydefaults=None, annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
+
def testGetArgSpecOnPartialInvalidArgspec(self):
"""Tests getargspec on partial function that doesn't have valid argspec."""
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index ad85a44f8d..ebb72079ef 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -52,12 +52,17 @@ bool IsString(PyObject* o) {
// returned value is a list.
//
// As with PyMapping_Keys, returns a new reference.
+//
+// On failure, returns nullptr.
PyObject* MappingKeys(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
return PyMapping_Keys(o);
#else
static char key_method_name[] = "keys";
Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
+ if (PyErr_Occurred() || raw_result.get() == nullptr) {
+ return nullptr;
+ }
return PySequence_Fast(
raw_result.get(),
"The '.keys()' method of a custom mapping returned a non-sequence.");
@@ -260,6 +265,9 @@ class ValIterator {
// Return a borrowed reference to the next element from iterable.
// Return nullptr when iteration is over.
PyObject* next() {
+ if (TF_PREDICT_FALSE(seq_ == nullptr)) {
+ return nullptr;
+ }
PyObject* element = nullptr;
if (index_ < size_) {
// Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
@@ -430,16 +438,26 @@ bool FlattenHelper(
// 'dict1' and 'dict2' are assumed to be Python dictionaries.
void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
bool* is_type_error) {
- PyObject* k1 = MappingKeys(dict1);
- PyObject* k2 = MappingKeys(dict2);
+ Safe_PyObjectPtr k1(MappingKeys(dict1));
+ if (PyErr_Occurred() || k1.get() == nullptr) {
+ *error_msg =
+ ("The two dictionaries don't have the same set of keys. Failed to "
+ "fetch keys.");
+ return;
+ }
+ Safe_PyObjectPtr k2(MappingKeys(dict2));
+ if (PyErr_Occurred() || k2.get() == nullptr) {
+ *error_msg =
+ ("The two dictionaries don't have the same set of keys. Failed to "
+ "fetch keys.");
+ return;
+ }
*is_type_error = false;
*error_msg = tensorflow::strings::StrCat(
"The two dictionaries don't have the same set of keys. "
"First structure has keys ",
- PyObjectToString(k1), ", while second structure has keys ",
- PyObjectToString(k2));
- Py_DECREF(k1);
- Py_DECREF(k2);
+ PyObjectToString(k1.get()), ", while second structure has keys ",
+ PyObjectToString(k2.get()));
}
// Returns true iff there were no "internal" errors. In other words,
@@ -522,7 +540,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return true;
}
- if (PyDict_Check(o1)) {
+ if (PyDict_Check(o1) && PyDict_Check(o2)) {
if (PyDict_Size(o1) != PyDict_Size(o2)) {
SetDifferentKeysError(o1, o2, error_msg, is_type_error);
return true;
@@ -741,6 +759,11 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
string error_msg;
bool is_type_error = false;
AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error);
+ if (PyErr_Occurred()) {
+ // Don't hide Python exceptions while checking (e.g. errors fetching keys
+ // from custom mappings).
+ return nullptr;
+ }
if (!error_msg.empty()) {
PyErr_SetString(
is_type_error ? PyExc_TypeError : PyExc_ValueError,
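
The pattern behind the util.cc changes above is worth spelling out: whenever a helper can fail with a Python exception set (as MappingKeys now documents), the caller must check and return nullptr so the original exception propagates instead of being masked by a synthesized "different keys" message. A minimal, self-contained sketch of that propagate-don't-mask pattern for an ordinary CPython extension helper (the function name DescribeKeys is made up for illustration):

    #include <Python.h>

    // Returns a new reference to repr(mapping.keys()), or nullptr with the
    // original Python exception left in place for the caller to surface.
    static PyObject* DescribeKeys(PyObject* mapping) {
      PyObject* keys = PyMapping_Keys(mapping);  // New reference, or nullptr.
      if (keys == nullptr || PyErr_Occurred()) {
        return nullptr;  // Do not overwrite the real error.
      }
      PyObject* repr = PyObject_Repr(keys);
      Py_DECREF(keys);
      return repr;  // May itself be nullptr with an exception set.
    }
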
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index ea87744b22..7f851e3646 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1121,6 +1121,40 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator) = 0;
+ // Batched gemm with strides instead of pointer arrays.
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+ virtual bool DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) = 0;
+
// Computes a matrix-matrix product where one input matrix is Hermitian:
//
// c <- alpha * a * b + beta * c,
@@ -1990,6 +2024,38 @@ class BlasSupport {
int ldb, std::complex<double> beta, \
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \
int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, \
+ const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \
+ const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \
+ DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
+ int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \
+ int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \
+ int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, double alpha, \
+ const DeviceMemory<double> &a, int lda, int64 stride_a, \
+ const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \
+ DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ int64 stride_c, int batch_count); \
+ bool DoBlasGemmStridedBatched( \
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, \
+ uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
+ int ldc, int64 stride_c, int batch_count); \
bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
uint64 m, uint64 n, std::complex<float> alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \
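
For orientation, a minimal caller-side sketch of the new strided-batched interface as exposed through Stream (the executor pointer, the 2x2 shapes, and the packed column-major layout are assumptions for the example, not part of the change):

    #include "tensorflow/stream_executor/stream.h"
    #include "tensorflow/stream_executor/stream_executor.h"

    namespace se = stream_executor;

    void BatchedGemmSketch(se::StreamExecutor* executor) {
      se::Stream stream(executor);
      stream.Init();

      const int batch_count = 8, m = 2, n = 2, k = 2;
      // Densely packed column-major batches: each matrix starts a fixed number
      // of elements after the previous one, which is exactly what the strides encode.
      se::DeviceMemory<float> a = executor->AllocateArray<float>(batch_count * m * k);
      se::DeviceMemory<float> b = executor->AllocateArray<float>(batch_count * k * n);
      se::DeviceMemory<float> c = executor->AllocateArray<float>(batch_count * m * n);
      // ... fill a and b on the device ...

      stream.ThenBlasGemmStridedBatched(
          se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
          m, n, k, /*alpha=*/1.0f, a, /*lda=*/m, /*stride_a=*/m * k,
          b, /*ldb=*/k, /*stride_b=*/k * n, /*beta=*/0.0f, &c, /*ldc=*/m,
          /*stride_c=*/m * n, batch_count);
      if (!stream.BlockHostUntilDone().ok()) {
        // Handle the error; c is not valid.
      }
    }
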
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 6988389f29..ab7091b3f5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
#if CUDA_VERSION >= 8000
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched)
#endif
#if CUDA_VERSION >= 9000
@@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
#if CUDA_VERSION >= 9010
STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx)
#endif
} // namespace wrap
@@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm(
stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
&cc_minor);
- // GPUs < sm_70 don't support Volta hardware.
+ // GPUs < sm_70 don't support tensor ops.
if (cc_major >= 7 && TensorOpMathEnabled()) {
use_tensor_ops = true;
}
@@ -2623,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched(
return status.ok();
}
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
+ int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ bool use_tensor_ops = false;
+#if CUDA_VERSION >= 9000
+ int cc_major, cc_minor;
+ if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
+ &cc_major, &cc_minor)) {
+ // GPUs < sm_70 don't support tensor ops.
+ if (cc_major >= 7 && TensorOpMathEnabled()) {
+ use_tensor_ops = true;
+ }
+#if CUDA_VERSION >= 9010
+ if (cc_major >= 5) {
+ cublasGemmAlgo_t algo =
+ (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasGemmStridedBatchedEx, stream,
+ true /* = pointer_mode_host */, true /* = err_on_failure */,
+ use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb),
+ m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a,
+ CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c),
+ CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo);
+ if (ok) {
+ return true;
+ }
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+#endif
+ }
+#endif
+ // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
+ for (int batch = 0; batch < batch_count; ++batch) {
+ const auto *a_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(a) + batch * stride_a);
+ const auto *b_matrix =
+ reinterpret_cast<const __half *>(CUDAMemory(b) + batch * stride_b);
+ auto *c_matrix =
+ reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c);
+ bool ok = DoBlasInternalImpl(
+ wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
+        true /* = err_on_failure */, use_tensor_ops, CUDABlasTranspose(transa),
+ CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF,
+ lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix,
+ SE_CUDA_DATA_HALF, ldc);
+ if (!ok) {
+ LOG(ERROR) << "failed BLAS call, see log for details";
+ return false;
+ }
+ }
+ return true;
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
+ CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta,
+ CUDAMemoryMutable(c), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
+bool CUDABlas::DoBlasGemmStridedBatched(
+ Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+ uint64 n, uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ return DoBlasInternal(
+ wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a,
+ CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta),
+ CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count);
+}
+
bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
blas::UpperLower uplo, uint64 m, uint64 n,
std::complex<float> alpha,
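
The half-precision path above can fall back to a per-batch loop precisely because a stride is just a compact encoding of the per-batch pointers that the older pointer-array batched interface takes. A small host-side sketch of that correspondence (plain pointers are used only for illustration; real callers deal in DeviceMemory<> handles):

    #include <cstdint>
    #include <vector>

    // Expands a base pointer plus stride into the explicit pointer array the
    // pointer-array batched GEMM would have needed.
    std::vector<const float*> PointersFromStride(const float* base, int64_t stride,
                                                 int batch_count) {
      std::vector<const float*> ptrs(batch_count);
      for (int i = 0; i < batch_count; ++i) {
        ptrs[i] = base + i * stride;  // e.g. stride == lda * k for packed column-major A.
      }
      return ptrs;
    }
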
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 1c3940e92c..725f6aeaa4 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -3082,8 +3082,7 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
- // zero-initialized.
- // TODO(timshen): Add an nvbugs/ link.
+ // zero-initialized, nvbugs/2254619.
if (CUDNN_VERSION >= 7000 &&
algorithm_config.algorithm().algo_id() ==
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index dbece3adf9..f982f34b98 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/human_readable.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/static_threadlocal.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -66,14 +67,17 @@ class CreatedContexts {
return Live()->find(context) != Live()->end();
}
- // Adds context to the live set.
+ // Adds context to the live set, or returns it if it's already present.
static CudaContext* Add(CUcontext context) {
CHECK(context != nullptr);
mutex_lock lock(mu_);
- auto cuda_context = new CudaContext(context, next_id_++);
- Live()->insert(
- std::make_pair(context, std::unique_ptr<CudaContext>(cuda_context)));
- return cuda_context;
+ auto insert_result = Live()->insert(std::make_pair(context, nullptr));
+ auto it = insert_result.first;
+ if (insert_result.second) {
+ // context was not present in the map. Add it.
+ it->second = MakeUnique<CudaContext>(context, next_id_++);
+ }
+ return it->second.get();
}
// Removes context from the live set.
@@ -427,7 +431,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
*context = CreatedContexts::Add(new_context);
CHECK(*context != nullptr)
<< "success in this call must entail non-null result";
- VLOG(2) << "created context " << context << " for this thread";
+ VLOG(2) << "created or reused context " << context << " for this thread";
return port::Status::OK();
}
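
The CreatedContexts::Add change is the usual insert-or-construct idiom: attempt the map insert first and only build the value when the key was genuinely absent, so repeated calls with the same CUcontext hand back the same CudaContext. A generic, self-contained sketch of the same pattern (Widget, AddOrGet, and std::make_unique stand in for the internal types and MakeUnique):

    #include <map>
    #include <memory>

    struct Widget {
      explicit Widget(int id) : id(id) {}
      int id;
    };

    // Returns the existing Widget for `key`, or constructs one exactly once.
    Widget* AddOrGet(std::map<int, std::unique_ptr<Widget>>* live, int key,
                     int* next_id) {
      auto insert_result = live->insert({key, nullptr});
      if (insert_result.second) {
        // Key was not present before; fill in the freshly inserted slot.
        insert_result.first->second = std::make_unique<Widget>((*next_id)++);
      }
      return insert_result.first->second.get();
    }
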
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index b0c061fd74..a42a469df5 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) {
}
string ToVlogString(const DeviceMemoryBase *memory) {
- return ToVlogString(*memory);
+ return memory == nullptr ? "null" : ToVlogString(*memory);
}
string ToVlogString(const Eigen::half &h) {
@@ -211,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream,
// constructing all the strings in params is expensive.
CHECK(VLOG_IS_ON(1));
- string str = port::StrCat("Called Stream::", function_name, "(");
+ string str = port::StrCat(stream->DebugStreamPointers(),
+ " Called Stream::", function_name, "(");
const char *separator = "";
for (const auto &param : params) {
port::StrAppend(&str, separator, param.first, "=", param.second);
separator = ", ";
}
- port::StrAppend(&str, ") stream=", ToVlogString(stream));
+ port::StrAppend(&str, ")");
if (VLOG_IS_ON(10)) {
port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
}
@@ -1922,37 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer(
Stream *Stream::GetOrCreateSubStream() {
mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (stream.second) {
- stream.second = false;
- return stream.first.get();
+
+ // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
+ // we encounter along the way.
+ for (int64 index = 0; index < sub_streams_.size();) {
+ std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+ if (pair.second) {
+ // The sub_stream is reusable.
+ Stream *sub_stream = pair.first.get();
+ if (sub_stream->ok()) {
+ VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
+ << sub_stream->DebugStreamPointers();
+ pair.second = false;
+ return sub_stream;
+ }
+
+ // The stream is reusable and not ok. Streams have a monotonic state
+ // machine; the stream will remain in !ok forever. Swap it with the last
+ // stream and pop it off.
+ const int64 last = sub_streams_.size() - 1;
+ if (index != last) {
+ std::swap(pair, sub_streams_[last]);
+ }
+ sub_streams_.pop_back();
+ VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ } else {
+ // The sub_stream is not reusable, move on to the next one.
+ ++index;
}
}
+
+ // No streams are reusable; create a new stream.
sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
false);
Stream *sub_stream = sub_streams_.back().first.get();
sub_stream->Init();
CHECK(ok_) << "sub-stream failed to be initialized";
+ VLOG(1) << DebugStreamPointers() << " created new sub_stream "
+ << sub_stream->DebugStreamPointers();
return sub_stream;
}
void Stream::ReturnSubStream(Stream *sub_stream) {
mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (stream.first.get() == sub_stream) {
- // Streams have a monotonic state machine; if a stream
- // encounters an error, it will remain in an error state
- // forever. Only allow re-use of ok streams.
- //
- // TODO(toddw): Improve this mechanism, if necessary, to drop
- // failed streams completely.
- const bool ready_to_reuse = sub_stream->ok();
- stream.second = ready_to_reuse;
- return;
+
+ // Look for the sub-stream.
+ for (int64 index = 0; index < sub_streams_.size(); ++index) {
+ std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
+ if (pair.first.get() != sub_stream) {
+ continue;
+ }
+
+ // Found the sub_stream.
+ if (sub_stream->ok()) {
+ VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ pair.second = true;
+ } else {
+ // The returned stream is not ok. Streams have a monotonic state
+ // machine; the stream will remain in !ok forever. Swap it with the last
+ // stream and pop it off.
+ VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
+ << sub_stream->DebugStreamPointers();
+ const int64 last = sub_streams_.size() - 1;
+ if (index != last) {
+ std::swap(pair, sub_streams_[last]);
+ }
+ sub_streams_.pop_back();
}
+ return;
}
- LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
+
+ LOG(FATAL) << DebugStreamPointers()
+ << " did not create the returned sub-stream "
+ << sub_stream->DebugStreamPointers();
}
Stream &Stream::ThenStartTimer(Timer *t) {
@@ -1961,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) {
if (ok()) {
CheckError(parent_->StartTimer(this, t));
} else {
- LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
+ LOG(INFO) << DebugStreamPointers()
+ << " did not enqueue 'start timer': " << t;
}
return *this;
}
@@ -1972,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) {
if (ok()) {
CheckError(parent_->StopTimer(this, t));
} else {
- LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
+ LOG(INFO) << DebugStreamPointers()
+ << " did not enqueue 'stop timer': " << t;
}
return *this;
}
@@ -1985,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) {
CheckError(parent_->CreateStreamDependency(this, other));
} else {
SetError();
- LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
+ LOG(INFO) << DebugStreamPointers() << " did not wait for "
+ << other->DebugStreamPointers();
}
return *this;
}
@@ -2002,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) {
<< "at fault. Monitor for further errors.";
}
} else {
- LOG(INFO) << "stream " << this << " did not wait for an event.";
+ LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
}
return *this;
}
@@ -4685,6 +4734,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
scratch_allocator);
}
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<Eigen::half> &, int, int64,
+ const DeviceMemory<Eigen::half> &, int, int64, float,
+ DeviceMemory<Eigen::half> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
+ const DeviceMemory<float> &, int, int64,
+ const DeviceMemory<float> &, int, int64, float,
+ DeviceMemory<float> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
+ const DeviceMemory<double> &, int, int64,
+ const DeviceMemory<double> &, int, int64, double,
+ DeviceMemory<double> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<float>, const DeviceMemory<std::complex<float>> &,
+ int, int64, const DeviceMemory<std::complex<float>> &, int,
+ int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
+ int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
+Stream &Stream::ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count) {
+ VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+ PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
+ PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
+ PARAM(stride_c), PARAM(batch_count));
+
+ ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ std::complex<double>, const DeviceMemory<std::complex<double>> &,
+ int, int64, const DeviceMemory<std::complex<double>> &, int,
+ int64, std::complex<double>,
+ DeviceMemory<std::complex<double>> *, int, int64, int>
+ impl;
+ return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
+ transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
+ c, ldc, stride_c, batch_count);
+}
+
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
@@ -4693,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
CheckError(rng->SetSeed(this, seed, seed_bytes));
} else {
SetError();
- LOG(INFO) << "stream " << this << " unable to initialize RNG";
+ LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
}
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not set RNG seed: " << static_cast<const void *>(seed)
<< "; bytes: " << seed_bytes;
}
@@ -4711,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4727,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4743,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4758,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4774,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform(
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4790,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform(
CheckError(rng->DoPopulateRandUniform(this, values));
} else {
SetError();
- LOG(INFO) << "stream " << this
- << " attempting to perform RNG operation using StreamExecutor "
- "without RNG support.";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform RNG operation using StreamExecutor"
+ " without RNG support.";
}
}
return *this;
@@ -4805,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
if (ok()) {
CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy device-to-host; source: " << gpu_src.opaque();
}
return *this;
@@ -4818,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
if (ok()) {
CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy host-to-device; source: " << host_src;
}
return *this;
@@ -4831,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
if (ok()) {
CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memcpy gpu-to-gpu; source: " << &gpu_src;
}
return *this;
@@ -4843,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
if (ok()) {
CheckError(parent_->MemZero(this, location, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memzero GPU location; source: " << location;
}
return *this;
@@ -4856,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
if (ok()) {
CheckError(parent_->Memset32(this, location, pattern, size));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " did not memset GPU location; source: " << location
<< "; size: " << size << "; pattern: " << std::hex << pattern;
}
@@ -5125,7 +5288,7 @@ Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
if (ok()) {
CheckError(parent_->HostCallback(this, callback));
} else {
- LOG(INFO) << "stream " << this
+ LOG(INFO) << DebugStreamPointers()
<< " was in error state before adding host callback";
}
return *this;
@@ -5141,8 +5304,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5158,8 +5322,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5174,8 +5339,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5190,8 +5356,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5207,8 +5374,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5224,8 +5392,9 @@ Stream &Stream::ThenFft(fft::Plan *plan,
CheckError(fft->DoFft(this, plan, input, output));
} else {
SetError();
- LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
- "without FFT support";
+ LOG(INFO) << DebugStreamPointers()
+ << " attempting to perform FFT operation using StreamExecutor"
+ " without FFT support";
}
}
return *this;
@@ -5252,7 +5421,7 @@ port::Status Stream::BlockHostUntilDone() {
port::Status status = port::Status(
port::error::INTERNAL,
"stream did not block host until done; was already in an error state");
- LOG(INFO) << status << " " << this;
+ LOG(INFO) << DebugStreamPointers() << " " << status;
return status;
}
@@ -5263,4 +5432,10 @@ port::Status Stream::BlockHostUntilDone() {
return error;
}
+string Stream::DebugStreamPointers() const {
+ // Relies on the ToVlogString(const void*) overload above.
+ return port::StrCat("[stream=", ToVlogString(this),
+ ",impl=", ToVlogString(implementation_.get()), "]");
+}
+
} // namespace stream_executor
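
To summarize the contract the reworked sub-stream pool implements, here is a caller-side sketch (the executor pointer and the elided work are placeholders):

    #include "tensorflow/stream_executor/stream.h"

    namespace se = stream_executor;

    void RunOnSubStream(se::StreamExecutor* executor) {
      se::Stream main_stream(executor);
      main_stream.Init();

      // Reuses an ok sub-stream from the pool if one exists; !ok sub-streams
      // encountered during the scan are dropped, and a new one is created if
      // nothing is reusable.
      se::Stream* sub_stream = main_stream.GetOrCreateSubStream();

      // ... enqueue work on *sub_stream here ...

      // Hand it back to the parent. An ok sub-stream goes back into the pool;
      // one that hit an error is removed instead of being recycled.
      main_stream.ReturnSubStream(sub_stream);
    }
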
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 706442a666..4d41409fef 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -122,10 +122,14 @@ class Stream {
// Get or create a sub-stream from this stream. If there is any sub-stream in
// the pool that can be reused then just return this sub-stream. Otherwise
// create a new sub-stream.
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
// Return the sub-stream back to the host stream so that it can be reused
// later. Sub-streams that are !ok() will not be reused.
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone.
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
// Allocate temporary memories. The stream will deallocate them when blocked
@@ -1557,6 +1561,38 @@ class Stream {
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
+ int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
+ int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
+ int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
+ float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
+ int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
+ double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
+ int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
+ std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ int64 stride_c, int batch_count);
+ Stream &ThenBlasGemmStridedBatched(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
+ const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
+ std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ int64 stride_c, int batch_count);
// See BlasSupport::DoBlasHemm.
Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
@@ -2019,6 +2055,9 @@ class Stream {
// with this stream.
internal::TemporaryMemoryManager *temporary_memory_manager();
+ // Returns a debugging string "[stream=0x...,impl=0x...]".
+ string DebugStreamPointers() const;
+
private:
friend class host::HostBlas; // for parent_.
friend class host::HostFft; // for parent_.
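
To make the stride parameters of the new ThenBlasGemmStridedBatched declarations above concrete, here is a minimal host-side sketch of the addressing they imply: batch i reads a + i*stride_a and b + i*stride_b and writes c + i*stride_c. The helper names (GemmNoTrans, GemmStridedBatched) are hypothetical, and the loops assume plain column-major float buffers with no transposes; this illustrates the memory layout only, not the StreamExecutor/BLAS implementation.

    #include <cstdint>

    // Naive column-major GEMM for one batch: C = alpha * A * B + beta * C,
    // where A is m x k (leading dim lda), B is k x n (ldb), C is m x n (ldc).
    static void GemmNoTrans(int64_t m, int64_t n, int64_t k, float alpha,
                            const float* a, int64_t lda, const float* b,
                            int64_t ldb, float beta, float* c, int64_t ldc) {
      for (int64_t j = 0; j < n; ++j) {
        for (int64_t i = 0; i < m; ++i) {
          float acc = 0.0f;
          for (int64_t p = 0; p < k; ++p) {
            acc += a[i + p * lda] * b[p + j * ldb];
          }
          c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
        }
      }
    }

    // Strided batched GEMM: batch i uses a + i*stride_a, b + i*stride_b and
    // c + i*stride_c, mirroring the stride_a/stride_b/stride_c parameters of
    // the ThenBlasGemmStridedBatched declarations above.
    static void GemmStridedBatched(int64_t m, int64_t n, int64_t k, float alpha,
                                   const float* a, int64_t lda, int64_t stride_a,
                                   const float* b, int64_t ldb, int64_t stride_b,
                                   float beta, float* c, int64_t ldc,
                                   int64_t stride_c, int batch_count) {
      for (int i = 0; i < batch_count; ++i) {
        GemmNoTrans(m, n, k, alpha, a + i * stride_a, lda, b + i * stride_b, ldb,
                    beta, c + i * stride_c, ldc);
      }
    }
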
diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc
index 47dd675834..cfc051fd09 100644
--- a/tensorflow/stream_executor/stream_test.cc
+++ b/tensorflow/stream_executor/stream_test.cc
@@ -95,18 +95,18 @@ TEST_F(StreamTest, TwoSubStreams) {
EXPECT_NE(sub_stream3, sub_stream4);
}
-TEST_F(StreamTest, FailedSubStreamNotReused) {
+TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) {
std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
Stream stream(executor.get());
stream.Init();
EXPECT_TRUE(stream.ok());
- // Get a sub-stream.
+ // Get sub_stream1.
Stream* sub_stream1 = stream.GetOrCreateSubStream();
EXPECT_TRUE(sub_stream1->ok());
- // Force an error on the stream; here we call a method that requires
- // DNN support, which we know the Host platform doesn't support.
+ // Force an error on sub_stream1; here we call a method that requires DNN
+ // support, which we know the Host platform doesn't support.
sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
EXPECT_FALSE(sub_stream1->ok());
@@ -115,20 +115,84 @@ TEST_F(StreamTest, FailedSubStreamNotReused) {
Stream* sub_stream2 = stream.GetOrCreateSubStream();
EXPECT_TRUE(sub_stream2->ok());
- // The underlying streams should be different. They would have been
- // the same, but since we forced an error on sub_stream1, it will
- // not be re-used. Sadly we can't just check:
+ // The underlying sub_streams should be different. They would have been the
+ // same, but since we forced an error on sub_stream1, it will not be
+ // re-used. Sadly we can't just check:
// EXPECT_NE(sub_stream1, sub_stream2);
//
- // The above should hold logically, but it may fail if the new
- // stream instance allocated for sub_stream2 happens to reside in
- // the same memory address as sub_stream1.
+ // The above should hold logically, but it may fail if the new Stream instance
+ // allocated for sub_stream2 happens to reside in the same memory address as
+ // sub_stream1.
//
// The check that sub_stream2->ok() serves as a good-enough check.
- // Return sub_stream2 and get sub_stream3. The previous error on
- // sub_stream1 has no effect on these streams, and they are the
- // same.
+ // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+ // has no effect on these streams, and they are the same.
+ stream.ReturnSubStream(sub_stream2);
+ Stream* sub_stream3 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream3->ok());
+ EXPECT_EQ(sub_stream2, sub_stream3);
+}
+
+TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) {
+ std::unique_ptr<StreamExecutor> executor = NewStreamExecutor();
+ Stream stream(executor.get());
+ stream.Init();
+ EXPECT_TRUE(stream.ok());
+
+ // Get and return sub_stream1.
+ Stream* sub_stream1 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream1->ok());
+ stream.ReturnSubStream(sub_stream1);
+
+ // Force an error on sub_stream1; here we call a method that requires DNN
+ // support, which we know the Host platform doesn't support.
+ //
+ // It is a bit weird to use sub_stream1 after it has already been returned. By
+ // doing this, we're simulating an asynchronous error that occurs during
+  // execution of the sub_stream, after the sub_stream has been returned.
+ //
+ // E.g. the following is a common pattern of usage, where the execution of the
+ // operations enqueued onto the sub streams may occur after the streams have
+ // already been returned.
+ //
+ // void EnqueueOnSubStreams(Stream* stream) {
+  //   Stream* sub_stream1 = stream->GetOrCreateSubStream();
+  //   Stream* sub_stream2 = stream->GetOrCreateSubStream();
+  //   // ... enqueue some operations on the sub streams ...
+  //   stream->ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2);
+  //   stream->ReturnSubStream(sub_stream1);
+  //   stream->ReturnSubStream(sub_stream2);
+ // }
+ //
+ // Stream* main_stream = ...;
+ // EnqueueOnSubStreams(main_stream);
+  //   main_stream->BlockHostUntilDone();
+ //
+ // TODO(b/112196569): The semantics of failed sub-streams is error-prone;
+ // GetOrCreateSubStream can still return a sub-stream that has not encountered
+ // an error yet, but will encounter one in the future, based on previously
+ // enqueued operations.
+ sub_stream1->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(sub_stream1->ok());
+
+ // Get and return sub_stream2.
+ Stream* sub_stream2 = stream.GetOrCreateSubStream();
+ EXPECT_TRUE(sub_stream2->ok());
+
+ // The underlying streams should be different. They would have been the same,
+ // but since we forced an error on sub_stream1, it will not be re-used. Sadly
+ // we can't just check:
+ // EXPECT_NE(sub_stream1, sub_stream2);
+ //
+ // The above should hold logically, but it may fail if the new stream instance
+ // allocated for sub_stream2 happens to reside in the same memory address as
+ // sub_stream1.
+ //
+ // The check that sub_stream2->ok() serves as a good-enough check.
+
+ // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1
+ // has no effect on these streams, and they are the same.
stream.ReturnSubStream(sub_stream2);
Stream* sub_stream3 = stream.GetOrCreateSubStream();
EXPECT_TRUE(sub_stream3->ok());
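
For readers tracing the reuse behavior these tests pin down, the following stand-alone sketch captures the pooling idea under simplified assumptions (FakeSubStream and SubStreamPool are hypothetical types, not the StreamExecutor implementation): a returned sub-stream goes back into a pool, GetOrCreateSubStream only hands out a pooled entry that is still ok(), so a sub-stream that fails before or after being returned is never reused. A real implementation would also prune failed entries rather than leave them in the pool.

    #include <memory>
    #include <vector>

    // Hypothetical stand-in for a sub-stream that can enter an error state.
    struct FakeSubStream {
      bool ok = true;
    };

    // Minimal sketch of the pool semantics exercised by the tests above:
    // entries that are !ok are skipped at hand-out time instead of reused.
    class SubStreamPool {
     public:
      FakeSubStream* GetOrCreateSubStream() {
        for (Entry& entry : pool_) {
          if (!entry.in_use && entry.stream->ok) {
            entry.in_use = true;
            return entry.stream.get();
          }
        }
        pool_.push_back(Entry{std::make_unique<FakeSubStream>(), true});
        return pool_.back().stream.get();
      }

      void ReturnSubStream(FakeSubStream* sub_stream) {
        for (Entry& entry : pool_) {
          if (entry.stream.get() == sub_stream) {
            entry.in_use = false;
            return;
          }
        }
      }

     private:
      struct Entry {
        std::unique_ptr<FakeSubStream> stream;
        bool in_use;
      };
      std::vector<Entry> pool_;
    };
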
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 45c8117304..443c582360 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -1102,6 +1102,10 @@ def tf_kernel_library(
tf_gpu_kernel_library(
name=name + "_gpu", srcs=gpu_srcs, deps=deps, **kwargs)
cuda_deps.extend([":" + name + "_gpu"])
+ kwargs["tags"] = kwargs.get("tags", []) + [
+ "req_dep=%s" % clean_dep("//tensorflow/core:gpu_lib"),
+ "req_dep=@local_config_cuda//cuda:cuda_headers",
+ ]
tf_cuda_library(
name=name,
srcs=srcs,
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
index 1f9aeb6ad6..4f0147a523 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.data.Iterator"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "initializer"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index 40e82b18b6..e579fe6a1a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -135,7 +135,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
index 65cfad77d1..6f05cdd093 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
@@ -140,7 +140,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index 85f7c2bfed..56914e1746 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -135,7 +135,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
index 6a83129f7d..4c1c54001d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
@@ -140,7 +140,7 @@ tf_class {
}
member_method {
name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/ci_build/builds/android.sh b/tensorflow/tools/ci_build/builds/android.sh
index d81793efe0..7c3e308229 100755
--- a/tensorflow/tools/ci_build/builds/android.sh
+++ b/tensorflow/tools/ci_build/builds/android.sh
@@ -26,13 +26,19 @@ configure_android_workspace
# android_full.sh
echo "========== TensorFlow Demo Build Test =========="
+TARGETS=
+TARGETS+=" //tensorflow/examples/android:tensorflow_demo"
+# Also build the Eager Runtime so it remains compatible with Android for the
+# benefit of clients like TensorFlow Lite. For now it is enough to build only
+# :execute, which is what TF Lite needs.
+TARGETS+=" //tensorflow/core/common_runtime/eager:execute"
# Enable sandboxing so that zip archives don't get incorrectly packaged
# in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334)
# TODO(gunan): remove extra flags once sandboxing is enabled for all builds.
bazel --bazelrc=/dev/null build \
--compilation_mode=opt --cxxopt=-std=c++11 --fat_apk_cpu=x86_64 \
--spawn_strategy=sandboxed --genrule_strategy=sandboxed \
- //tensorflow/examples/android:tensorflow_demo
+ ${TARGETS}
echo "========== Makefile Build Test =========="
# Test Makefile build just to make sure it still works.
diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh
index 883bb93647..fef121ab5a 100755
--- a/tensorflow/tools/ci_build/builds/pip.sh
+++ b/tensorflow/tools/ci_build/builds/pip.sh
@@ -314,7 +314,10 @@ create_activate_virtualenv_and_install_tensorflow() {
# Upgrade pip so it supports tags such as cp27mu, manylinux1 etc.
echo "Upgrade pip in virtualenv"
- pip install --upgrade pip==9.0.1
+
+ # NOTE: pip install --upgrade pip leads to a documented TLS issue for
+  # some versions of Python
+ curl https://bootstrap.pypa.io/get-pip.py | python
# Force tensorflow reinstallation. Otherwise it may not get installed from
# last build if it had the same version number as previous build.
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 29680e6882..bbaf59c69a 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -97,7 +97,8 @@ fi
# TF_BUILD_APPEND_ARGUMENTS any user supplied args.
BAZEL_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py \
--build_tests_only -k --test_tag_filters=${PIP_TEST_FILTER_TAG} \
- --test_timeout 300,450,1200,3600 ${TF_BUILD_APPEND_ARGUMENTS}"
+ --test_timeout 300,450,1200,3600 ${TF_BUILD_APPEND_ARGUMENTS} \
+ --test_output=errors"
BAZEL_TEST_TARGETS="//${PIP_TEST_PREFIX}/tensorflow/contrib/... \
//${PIP_TEST_PREFIX}/tensorflow/python/... \
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 221b5b80fb..c3c537328f 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -61,11 +61,11 @@ rm -rf /usr/lib/python3/dist-packages/six*
# https://github.com/tensorflow/tensorflow/issues/6968
# This workaround isn't needed for Ubuntu 16.04 or later.
if $(cat /etc/*-release | grep -q 14.04); then
- pip2 install --no-binary=:all: --upgrade numpy==1.12.0
- pip3 install --no-binary=:all: --upgrade numpy==1.12.0
+ pip2 install --no-binary=:all: --upgrade numpy==1.14.5
+ pip3 install --no-binary=:all: --upgrade numpy==1.14.5
else
- pip2 install --upgrade numpy==1.12.0
- pip3 install --upgrade numpy==1.12.0
+ pip2 install --upgrade numpy==1.14.5
+ pip3 install --upgrade numpy==1.14.5
fi
pip2 install scipy==0.18.1
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 45a30c6e82..b6f5de57c9 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -58,7 +58,7 @@ rm -rf /usr/lib/python3/dist-packages/six*
# numpy needs to be installed from source to fix segfaults. See:
# https://github.com/tensorflow/tensorflow/issues/6968
# This workaround isn't needed for Ubuntu 16.04 or later.
-pip3.5 install --no-binary=:all: --upgrade numpy==1.12.0
+pip3.5 install --no-binary=:all: --upgrade numpy==1.14.5
pip3.5 install scipy==0.18.1
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index d66b2aa18a..8868664132 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -70,7 +70,7 @@ rm -rf /usr/lib/python3/dist-packages/six*
# numpy needs to be installed from source to fix segfaults. See:
# https://github.com/tensorflow/tensorflow/issues/6968
# This workaround isn't needed for Ubuntu 16.04 or later.
-pip3 install --no-binary=:all: --upgrade numpy==1.12.0
+pip3 install --no-binary=:all: --upgrade numpy==1.14.5
pip3 install scipy==0.18.1
@@ -101,7 +101,7 @@ pip3 install --upgrade termcolor
pip3 install --upgrade setuptools==39.1.0
# Keras
-pip3.5 install keras_applications==1.0.2
-pip3.5 install keras_preprocessing==1.0.1
+pip3 install keras_applications==1.0.2
+pip3 install keras_preprocessing==1.0.1
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh)
diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py
index e0acead919..b40e4155df 100644
--- a/tensorflow/tools/common/public_api.py
+++ b/tensorflow/tools/common/public_api.py
@@ -50,6 +50,7 @@ class PublicAPIVisitor(object):
# Each entry maps a module path to a name to ignore in traversal.
self._do_not_descend_map = {
'tf': [
+ 'compiler',
'core',
'examples',
'flags', # Don't add flags
diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile
index a3ff8211e3..bf06214009 100644
--- a/tensorflow/tools/docker/Dockerfile
+++ b/tensorflow/tools/docker/Dockerfile
@@ -30,7 +30,7 @@ RUN pip --no-cache-dir install \
ipykernel \
jupyter \
matplotlib \
- numpy \
+ numpy==1.14.5 \
pandas \
scipy \
sklearn \
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index f7fe4119da..6552588fac 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -35,7 +35,7 @@ RUN pip --no-cache-dir install \
jupyter \
matplotlib \
mock \
- numpy \
+ numpy==1.14.5 \
scipy \
sklearn \
pandas \
@@ -76,7 +76,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
# TODO(craigcitro): Don't install the pip package, since it makes it
# more difficult to experiment with local changes. Instead, just add
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 340f96df48..f4c83f85d4 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -51,7 +51,7 @@ RUN pip --no-cache-dir install \
jupyter \
matplotlib \
mock \
- numpy \
+ numpy==1.14.5 \
scipy \
sklearn \
pandas \
@@ -92,7 +92,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index c85641b383..f0c7118ecb 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -3,7 +3,7 @@ FROM ubuntu:16.04
LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
# These parameters can be overridden by parameterized_docker_build.sh
-ARG TF_BUILD_VERSION=r1.9
+ARG TF_BUILD_VERSION=r1.10
ARG PYTHON="python"
ARG PYTHON3_DEV=""
ARG WHL_DIR="/tmp/pip"
@@ -73,7 +73,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.14.1
+ENV BAZEL_VERSION 0.15.0
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index 28d4371da3..5ec1e60f00 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -38,7 +38,7 @@ RUN pip --no-cache-dir install \
ipykernel \
jupyter \
matplotlib \
- numpy \
+ numpy==1.14.5 \
pandas \
scipy \
sklearn \
diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md
index 525f2995ce..a286e8a212 100644
--- a/tensorflow/tools/docker/README.md
+++ b/tensorflow/tools/docker/README.md
@@ -87,8 +87,10 @@ export TF_DOCKER_BUILD_IS_DEVEL=NO
export TF_DOCKER_BUILD_TYPE=CPU
export TF_DOCKER_BUILD_PYTHON_VERSION=PYTHON2
-export NIGHTLY_VERSION="1.head"
-export TF_DOCKER_BUILD_CENTRAL_PIP=$(echo ${TF_DOCKER_BUILD_PYTHON_VERSION} | sed s^PYTHON2^http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=${TF_DOCKER_BUILD_PYTHON_VERSION},label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp27-cp27mu-manylinux1_x86_64.whl^ | sed s^PYTHON3^http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp35-cp35m-manylinux1_x86_64.whl^)
+pip download --no-deps tf-nightly
+
+export TF_DOCKER_BUILD_CENTRAL_PIP=$(ls tf_nightly*.whl)
+export TF_DOCKER_BUILD_CENTRAL_PIP_IS_LOCAL=1
tensorflow/tools/docker/parameterized_docker_build.sh
```
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index fdc2862a12..06ee2307e5 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -70,6 +70,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
"//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 1f4c3d47bf..07ed7f5195 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -45,13 +45,13 @@ DOCLINES = __doc__.split('\n')
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.9.0'
+_VERSION = '1.10.0-rc1'
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
'gast >= 0.2.0',
- 'numpy >= 1.13.3',
+ 'numpy >= 1.13.3, <= 1.14.5',
'six >= 1.10.0',
'protobuf >= 3.6.0',
'setuptools <= 39.1.0',
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index c145cfe72e..702698abed 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -6,7 +6,6 @@ load("//third_party:nccl/nccl_configure.bzl", "nccl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
-
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure")
@@ -15,926 +14,928 @@ load("//third_party:repo.bzl", "tf_http_archive")
load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain")
load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external")
load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external")
-load("//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl",
- "def_file_filter_configure")
-
+load(
+ "//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl",
+ "def_file_filter_configure",
+)
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
- return str(Label(dep))
+ return str(Label(dep))
# If TensorFlow is linked as a submodule.
# path_prefix is no longer used.
# tf_repo_name is thought to be under consideration.
-def tf_workspace(path_prefix="", tf_repo_name=""):
- # Note that we check the minimum bazel version in WORKSPACE.
- clang6_configure(name="local_config_clang6")
- cc_download_clang_toolchain(name="local_config_download_clang")
- cuda_configure(name="local_config_cuda")
- tensorrt_configure(name="local_config_tensorrt")
- nccl_configure(name="local_config_nccl")
- git_configure(name="local_config_git")
- sycl_configure(name="local_config_sycl")
- syslibs_configure(name="local_config_syslibs")
- python_configure(name="local_config_python")
-
- # For windows bazel build
- # TODO: Remove def file filter when TensorFlow can export symbols properly on Windows.
- def_file_filter_configure(name = "local_config_def_file_filter")
-
- # Point //external/local_config_arm_compiler to //external/arm_compiler
- arm_compiler_configure(
- name="local_config_arm_compiler",
- remote_config_repo="../arm_compiler",
- build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD"))
-
- mkl_repository(
- name = "mkl_linux",
- urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz"
- ],
- sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725",
- strip_prefix = "mklml_lnx_2018.0.3.20180406",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD")
- )
- mkl_repository(
- name = "mkl_windows",
- urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip",
- "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip"
- ],
- sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694",
- strip_prefix = "mklml_win_2018.0.3.20180406",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD")
- )
- mkl_repository(
- name = "mkl_darwin",
- urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz"
- ],
- sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b",
- strip_prefix = "mklml_mac_2018.0.3.20180406",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD")
- )
-
- if path_prefix:
- print("path_prefix was specified to tf_workspace but is no longer used " +
- "and will be removed in the future.")
-
- tf_http_archive(
- name = "mkl_dnn",
- urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/v0.14.tar.gz",
- "https://github.com/intel/mkl-dnn/archive/v0.14.tar.gz",
- ],
- sha256 = "efebc53882856afec86457a2da644693f5d59c68772d41d640d6b60a8efc4eb0",
- strip_prefix = "mkl-dnn-0.14",
- build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
- )
-
- tf_http_archive(
- name = "com_google_absl",
- urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz",
- ],
- sha256 = "1273a1434ced93bc3e703a48c5dced058c95e995c8c009e9bdcb24a69e2180e9",
- strip_prefix = "abseil-cpp-9613678332c976568272c8f4a78631a29159271d",
- build_file = clean_dep("//third_party:com_google_absl.BUILD"),
- )
-
- tf_http_archive(
- name = "eigen_archive",
- urls = [
- "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
- "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
- ],
- sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
- strip_prefix = "eigen-eigen-fd6845384b86",
- build_file = clean_dep("//third_party:eigen.BUILD"),
- )
-
- tf_http_archive(
- name = "arm_compiler",
- sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969",
- strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf",
- urls = [
- "https://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz",
- # Please uncomment me, when the next upgrade happens. Then
- # remove the whitelist entry in third_party/repo.bzl.
- # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz",
- ],
- build_file = clean_dep("//:arm_compiler.BUILD"),
- )
-
- tf_http_archive(
- name = "libxsmm_archive",
- urls = [
- "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz",
- "https://github.com/hfp/libxsmm/archive/1.9.tar.gz",
- ],
- sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
- strip_prefix = "libxsmm-1.9",
- build_file = clean_dep("//third_party:libxsmm.BUILD"),
- )
-
- tf_http_archive(
- name = "ortools_archive",
- urls = [
- "https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz",
- "https://github.com/google/or-tools/archive/v6.7.2.tar.gz",
- ],
- sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
- strip_prefix = "or-tools-6.7.2/src",
- build_file = clean_dep("//third_party:ortools.BUILD"),
- )
-
- tf_http_archive(
- name = "com_googlesource_code_re2",
- urls = [
- "https://mirror.bazel.build/github.com/google/re2/archive/2018-04-01.tar.gz",
- "https://github.com/google/re2/archive/2018-04-01.tar.gz",
-
- ],
- sha256 = "2f945446b71336e7f5a2bcace1abcf0b23fbba368266c6a1be33de3de3b3c912",
- strip_prefix = "re2-2018-04-01",
- system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
- )
-
- tf_http_archive(
- name = "com_github_googlecloudplatform_google_cloud_cpp",
- urls = [
- "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
- "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
- ],
- sha256 = "a34f3c50b237686dc870b13baaa6a5836ce3473f2f2a02717299f0ff318372db",
- strip_prefix = "google-cloud-cpp-f875700a023bdd706333cde45aee8758b272c357",
- )
-
- tf_http_archive(
- name = "com_github_googleapis_googleapis",
- urls = [
- "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
- "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
- ],
- sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
- strip_prefix="googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
- build_file = clean_dep("//third_party:googleapis.BUILD"),
- )
-
- tf_http_archive(
- name = "gemmlowp",
- urls = [
- "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
- "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
- ],
- sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
- strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
- )
-
- tf_http_archive(
- name = "farmhash_archive",
- urls = [
- "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
- "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
- ],
- sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
- strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
- build_file = clean_dep("//third_party:farmhash.BUILD"),
- )
-
- tf_http_archive(
- name = "highwayhash",
- urls = [
- "http://mirror.bazel.build/github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
- "https://github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
- ],
- sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
- strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
- build_file = clean_dep("//third_party:highwayhash.BUILD"),
- )
-
- tf_http_archive(
- name = "nasm",
- urls = [
- "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
- "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
- "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
- ],
- sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
- strip_prefix = "nasm-2.13.03",
- build_file = clean_dep("//third_party:nasm.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
- )
-
- tf_http_archive(
- name = "jpeg",
- urls = [
- "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
- "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
- ],
- sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde",
- strip_prefix = "libjpeg-turbo-1.5.3",
- build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
- )
-
- tf_http_archive(
- name = "png_archive",
- urls = [
- "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
- "https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
- ],
- sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
- strip_prefix = "libpng-1.6.34",
- build_file = clean_dep("//third_party:png.BUILD"),
- patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
- system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
- )
-
- tf_http_archive(
- name = "org_sqlite",
- urls = [
- "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
- "https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
- ],
- sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
- strip_prefix = "sqlite-amalgamation-3240000",
- build_file = clean_dep("//third_party:sqlite.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
- )
-
- tf_http_archive(
- name = "gif_archive",
- urls = [
- "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
- "http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
- ],
- sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
- strip_prefix = "giflib-5.1.4",
- build_file = clean_dep("//third_party:gif.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
- )
-
- tf_http_archive(
- name = "six_archive",
- urls = [
- "https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
- "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
- ],
- sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
- strip_prefix = "six-1.10.0",
- build_file = clean_dep("//third_party:six.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
- )
-
- tf_http_archive(
- name = "astor_archive",
- urls = [
- "https://mirror.bazel.build/pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
- "https://pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
- ],
- sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
- strip_prefix = "astor-0.6.2",
- build_file = clean_dep("//third_party:astor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
- )
-
- tf_http_archive(
- name = "gast_archive",
- urls = [
- "https://mirror.bazel.build/pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
- "https://pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
- ],
- sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
- strip_prefix = "gast-0.2.0",
- build_file = clean_dep("//third_party:gast.BUILD"),
- )
-
- tf_http_archive(
- name = "termcolor_archive",
- urls = [
- "https://mirror.bazel.build/pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
- "https://pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
- ],
- sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
- strip_prefix = "termcolor-1.1.0",
- build_file = clean_dep("//third_party:termcolor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
- )
-
- tf_http_archive(
- name = "absl_py",
- urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
- "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
- ],
- sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
- strip_prefix = "abseil-py-pypi-v0.2.2",
- )
-
- tf_http_archive(
- name = "org_python_pypi_backports_weakref",
- urls = [
- "https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
- "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
- ],
- sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
- strip_prefix = "backports.weakref-1.0rc1/src",
- build_file = clean_dep("//third_party:backports_weakref.BUILD"),
- )
-
- filegroup_external(
- name = "org_python_license",
- licenses = ["notice"], # Python 2.0
- sha256_urls = {
- "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [
- "https://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt",
- "https://docs.python.org/2.7/_sources/license.txt",
- ],
- },
- )
-
- tf_http_archive(
- name = "protobuf_archive",
- urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
- "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
- ],
- sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
- strip_prefix = "protobuf-3.6.0",
- )
-
- # We need to import the protobuf library under the names com_google_protobuf
- # and com_google_protobuf_cc to enable proto_library support in bazel.
- # Unfortunately there is no way to alias http_archives at the moment.
- tf_http_archive(
- name = "com_google_protobuf",
- urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
- "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
- ],
- sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
- strip_prefix = "protobuf-3.6.0",
- )
-
- tf_http_archive(
- name = "com_google_protobuf_cc",
- urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
- "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
- ],
- sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
- strip_prefix = "protobuf-3.6.0",
- )
-
- tf_http_archive(
- name = "nsync",
- urls = [
- "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz",
- "https://github.com/google/nsync/archive/1.20.0.tar.gz",
- ],
- sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd",
- strip_prefix = "nsync-1.20.0",
- )
-
- tf_http_archive(
- name = "com_google_googletest",
- urls = [
- "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip",
- "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip",
- ],
- sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d",
- strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6",
- )
-
- tf_http_archive(
- name = "com_github_gflags_gflags",
- urls = [
- "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
- "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
- ],
- sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
- strip_prefix = "gflags-2.2.1",
- )
-
- tf_http_archive(
- name = "pcre",
- sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5",
- urls = [
- "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
- "http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
- ],
- strip_prefix = "pcre-8.42",
- build_file = clean_dep("//third_party:pcre.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
- )
-
- tf_http_archive(
- name = "swig",
- sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453",
- urls = [
- "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
- "http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
- "http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
- ],
- strip_prefix = "swig-3.0.8",
- build_file = clean_dep("//third_party:swig.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
- )
-
- tf_http_archive(
- name = "curl",
- sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5",
- urls = [
- "https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz",
- "https://curl.haxx.se/download/curl-7.60.0.tar.gz",
- ],
- strip_prefix = "curl-7.60.0",
- build_file = clean_dep("//third_party:curl.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
- )
-
- tf_http_archive(
- name = "grpc",
- urls = [
- "https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz",
- "https://github.com/grpc/grpc/archive/v1.13.0.tar.gz",
- ],
- sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
- strip_prefix = "grpc-1.13.0",
- system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
- )
-
- tf_http_archive(
- name = "linenoise",
- sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
- urls = [
- "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
- "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
- ],
- strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
- build_file = clean_dep("//third_party:linenoise.BUILD"),
- )
-
- # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror.
- # Switch to an official source of snapshots if/when possible.
- tf_http_archive(
- name = "llvm",
- urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428.tar.gz",
- ],
- sha256 = "c6cbb21acd46e3e00faa8c379595ecffb99ef77622da17f29371db2bfad1d3d3",
- strip_prefix = "llvm-7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428",
- build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
- )
-
- tf_http_archive(
- name = "lmdb",
- urls = [
- "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
- "https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
- ],
- sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
- strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
- build_file = clean_dep("//third_party:lmdb.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
- )
-
- tf_http_archive(
- name = "jsoncpp_git",
- urls = [
- "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
- "https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
- ],
- sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
- strip_prefix = "jsoncpp-1.8.4",
- build_file = clean_dep("//third_party:jsoncpp.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
- )
-
- tf_http_archive(
- name = "boringssl",
- urls = [
- "https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz",
- "https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz",
- ],
- sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3",
- strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778",
- )
-
- tf_http_archive(
- name = "zlib_archive",
- urls = [
- "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
- "https://zlib.net/zlib-1.2.11.tar.gz",
- ],
- sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
- strip_prefix = "zlib-1.2.11",
- build_file = clean_dep("//third_party:zlib.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
- )
-
- tf_http_archive(
- name = "fft2d",
- urls = [
- "https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
- "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
- ],
- sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
- build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
- )
-
- tf_http_archive(
- name = "snappy",
- urls = [
- "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz",
- "https://github.com/google/snappy/archive/1.1.7.tar.gz",
- ],
- sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
- strip_prefix = "snappy-1.1.7",
- build_file = clean_dep("//third_party:snappy.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
- )
-
- tf_http_archive(
- name = "nccl_archive",
- urls = [
- "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
- "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
- ],
- sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
- strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
- build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
- )
-
- tf_http_archive(
- name = "kafka",
- urls = [
- "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
- "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
- ],
- sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c",
- strip_prefix = "librdkafka-0.11.4",
- build_file = clean_dep("//third_party:kafka/BUILD"),
- patch_file = clean_dep("//third_party/kafka:config.patch"),
- )
-
- tf_http_archive(
- name = "aws",
- urls = [
- "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
- "https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
- ],
- sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
- strip_prefix = "aws-sdk-cpp-1.3.15",
- build_file = clean_dep("//third_party:aws.BUILD"),
- )
-
- java_import_external(
- name = "junit",
- jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a",
- jar_urls = [
- "https://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
- "http://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
- "http://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar",
- ],
- licenses = ["reciprocal"], # Common Public License Version 1.0
- testonly_ = True,
- deps = ["@org_hamcrest_core"],
- )
-
- java_import_external(
- name = "org_hamcrest_core",
- jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9",
- jar_urls = [
- "https://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
- "http://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
- "http://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
- ],
- licenses = ["notice"], # New BSD License
- testonly_ = True,
- )
-
- tf_http_archive(
- name = "jemalloc",
- urls = [
- "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
- "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
- ],
- sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
- strip_prefix = "jemalloc-4.4.0",
- build_file = clean_dep("//third_party:jemalloc.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
- )
-
- java_import_external(
- name = "com_google_testing_compile",
- jar_sha256 = "edc180fdcd9f740240da1a7a45673f46f59c5578d8cd3fbc912161f74b5aebb8",
- jar_urls = [
- "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar",
- "http://repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar",
- ],
- licenses = ["notice"], # New BSD License
- testonly_ = True,
- deps = ["@com_google_guava", "@com_google_truth"],
- )
-
- java_import_external(
- name = "com_google_truth",
- jar_sha256 = "032eddc69652b0a1f8d458f999b4a9534965c646b8b5de0eba48ee69407051df",
- jar_urls = [
- "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar",
- "http://repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar",
- ],
- licenses = ["notice"], # Apache 2.0
- testonly_ = True,
- deps = ["@com_google_guava"],
- )
-
- java_import_external(
- name = "org_checkerframework_qual",
- jar_sha256 = "a17501717ef7c8dda4dba73ded50c0d7cde440fd721acfeacbf19786ceac1ed6",
- jar_urls = [
- "http://mirror.bazel.build/repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar",
- "http://repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar",
- ],
- licenses = ["notice"], # Apache 2.0
- )
-
- java_import_external(
- name = "com_squareup_javapoet",
- jar_sha256 = "5bb5abdfe4366c15c0da3332c57d484e238bd48260d6f9d6acf2b08fdde1efea",
- jar_urls = [
- "http://mirror.bazel.build/repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar",
- "http://repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar",
- ],
- licenses = ["notice"], # Apache 2.0
- )
-
- tf_http_archive(
- name = "com_google_pprof",
- urls = [
- "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
- "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
- ],
- sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
- strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
- build_file = clean_dep("//third_party:pprof.BUILD"),
- )
-
- tf_http_archive(
- name = "cub_archive",
- urls = [
- "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
- "https://github.com/NVlabs/cub/archive/1.8.0.zip",
- ],
- sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
- strip_prefix = "cub-1.8.0",
- build_file = clean_dep("//third_party:cub.BUILD"),
- )
-
- tf_http_archive(
- name = "cython",
- sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa",
- urls = [
- "https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz",
- "https://github.com/cython/cython/archive/0.28.4.tar.gz",
- ],
- strip_prefix = "cython-0.28.4",
- build_file = clean_dep("//third_party:cython.BUILD"),
- delete = ["BUILD.bazel"],
- system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
- )
-
- tf_http_archive(
- name = "bazel_toolchains",
- urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
- ],
- strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
- sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
- )
-
- tf_http_archive(
- name = "arm_neon_2_x86_sse",
- sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5",
- strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d",
- urls = [
- "https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
- "https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
- ],
- build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
- )
-
- tf_http_archive(
- name = "flatbuffers",
- strip_prefix = "flatbuffers-1.9.0",
- sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3",
- urls = [
- "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
- "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
- ],
- build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:flatbuffers.BUILD"),
- )
-
- native.new_http_archive(
- name = "double_conversion",
- urls = [
- "https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
- ],
- sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
- strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
- build_file = clean_dep("//third_party:double_conversion.BUILD")
- )
-
- tf_http_archive(
- name = "tflite_mobilenet",
- sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
- urls = [
- "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
- "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
- ],
- build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
- )
-
- tf_http_archive(
- name = "tflite_mobilenet_ssd",
- sha256 = "767057f2837a46d97882734b03428e8dd640b93236052b312b2f0e45613c1cf0",
- urls = [
- "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
- "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
- ],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
- )
- tf_http_archive(
- name = "tflite_mobilenet_ssd_quant",
- sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
- urls = ["https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
- "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
- ],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
- )
-
- tf_http_archive(
- name = "tflite_conv_actions_frozen",
- sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
- urls = [
- "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
- "https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
- ],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
- )
-
- tf_http_archive(
- name = "tflite_smartreply",
- sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
- urls = [
- "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
- "https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip"
- ],
- build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
- )
-
- tf_http_archive(
- name = "tflite_ovic_testdata",
- sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c",
- urls = [
- "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
- "https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
- ],
- build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
- strip_prefix = "ovic",
- )
-
- tf_http_archive(
- name = "build_bazel_rules_android",
- sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
- urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
- "https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
- ],
- strip_prefix = "rules_android-0.1.1",
- )
-
- tf_http_archive(
- name = "ngraph",
- urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
- "https://github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
- ],
- sha256 = "cb35d3d98836f615408afd18371fb13e3400711247e0d822ba7f306c45e9bb2c",
- strip_prefix = "ngraph-0.5.0",
- build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
- )
-
- tf_http_archive(
- name = "nlohmann_json_lib",
- urls = [
- "https://mirror.bazel.build/github.com/nlohmann/json/archive/v3.1.1.tar.gz",
- "https://github.com/nlohmann/json/archive/v3.1.1.tar.gz",
- ],
- sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
- strip_prefix = "json-3.1.1",
- build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
- )
-
- tf_http_archive(
- name = "ngraph_tf",
- urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
- "https://github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz"
- ],
- sha256 = "7919332cb15120101c3e05c1b969a5e029a6411581312583c8f80b6aaaa83072",
- strip_prefix = "ngraph-tf-0.3.0-rc1",
- build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
- )
-
- ##############################################################################
- # BIND DEFINITIONS
- #
- # Please do not add bind() definitions unless we have no other choice.
- # If that ends up being the case, please leave a comment explaining
- # why we can't depend on the canonical build target.
-
- # gRPC wants a cares dependency but its contents are not actually
- # important since we have set GRPC_ARES=0 in tools/bazel.rc
- native.bind(
- name = "cares",
- actual = "@grpc//third_party/nanopb:nanopb",
- )
-
- # Needed by Protobuf
- native.bind(
- name = "grpc_cpp_plugin",
- actual = "@grpc//:grpc_cpp_plugin",
- )
- native.bind(
- name = "grpc_python_plugin",
- actual = "@grpc//:grpc_python_plugin",
- )
-
- native.bind(
- name = "grpc_lib",
- actual = "@grpc//:grpc++",
- )
-
- native.bind(
- name = "grpc_lib_unsecure",
- actual = "@grpc//:grpc++_unsecure",
- )
-
- # Needed by gRPC
- native.bind(
- name = "libssl",
- actual = "@boringssl//:ssl",
- )
-
- # Needed by gRPC
- native.bind(
- name = "nanopb",
- actual = "@grpc//third_party/nanopb:nanopb",
- )
-
- # Needed by gRPC
- native.bind(
- name = "protobuf",
- actual = "@protobuf_archive//:protobuf",
- )
-
- # gRPC expects //external:protobuf_clib and //external:protobuf_compiler
- # to point to Protobuf's compiler library.
- native.bind(
- name = "protobuf_clib",
- actual = "@protobuf_archive//:protoc_lib",
- )
-
- # Needed by gRPC
- native.bind(
- name = "protobuf_headers",
- actual = "@protobuf_archive//:protobuf_headers",
- )
-
- # Needed by Protobuf
- native.bind(
- name = "python_headers",
- actual = clean_dep("//third_party/python_runtime:headers"),
- )
-
- # Needed by Protobuf
- native.bind(
- name = "six",
- actual = "@six_archive//:six",
- )
-
- # Needed by gRPC
- native.bind(
- name = "zlib",
- actual = "@zlib_archive//:zlib",
- )
+def tf_workspace(path_prefix = "", tf_repo_name = ""):
+ # Note that we check the minimum bazel version in WORKSPACE.
+ clang6_configure(name = "local_config_clang6")
+ cc_download_clang_toolchain(name = "local_config_download_clang")
+ cuda_configure(name = "local_config_cuda")
+ tensorrt_configure(name = "local_config_tensorrt")
+ nccl_configure(name = "local_config_nccl")
+ git_configure(name = "local_config_git")
+ sycl_configure(name = "local_config_sycl")
+ syslibs_configure(name = "local_config_syslibs")
+ python_configure(name = "local_config_python")
+
+ # For the Windows Bazel build
+ # TODO: Remove def file filter when TensorFlow can export symbols properly on Windows.
+ def_file_filter_configure(name = "local_config_def_file_filter")
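The configure rules above run when tf_workspace() is invoked from the top-level WORKSPACE file. A minimal sketch of that call site, assuming the standard //tensorflow:workspace.bzl location (the WORKSPACE contents themselves are not part of this diff):

    # Sketch of a WORKSPACE entry (assumption: workspace.bzl is loaded from its
    # usual //tensorflow:workspace.bzl path); tf_workspace() then declares every
    # external repository and bind() alias defined in this file.
    load("//tensorflow:workspace.bzl", "tf_workspace")

    tf_workspace()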
+
+ # Point //external/local_config_arm_compiler to //external/arm_compiler
+ arm_compiler_configure(
+ name = "local_config_arm_compiler",
+ remote_config_repo = "../arm_compiler",
+ build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD"),
+ )
+
+ mkl_repository(
+ name = "mkl_linux",
+ urls = [
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz",
+ ],
+ sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725",
+ strip_prefix = "mklml_lnx_2018.0.3.20180406",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ )
+ mkl_repository(
+ name = "mkl_windows",
+ urls = [
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip",
+ ],
+ sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694",
+ strip_prefix = "mklml_win_2018.0.3.20180406",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ )
+ mkl_repository(
+ name = "mkl_darwin",
+ urls = [
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz",
+ ],
+ sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b",
+ strip_prefix = "mklml_mac_2018.0.3.20180406",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ )
+
+ if path_prefix:
+ print("path_prefix was specified to tf_workspace but is no longer used " +
+ "and will be removed in the future.")
+
+ tf_http_archive(
+ name = "mkl_dnn",
+ urls = [
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz",
+ "https://github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz",
+ ],
+ sha256 = "da1f27f92453a65331197dd8e4992e810fb7b1c4e0b902a1da5611592df2b633",
+ strip_prefix = "mkl-dnn-0c1cf54b63732e5a723c5670f66f6dfb19b64d20",
+ build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "com_google_absl",
+ urls = [
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz",
+ ],
+ sha256 = "1273a1434ced93bc3e703a48c5dced058c95e995c8c009e9bdcb24a69e2180e9",
+ strip_prefix = "abseil-cpp-9613678332c976568272c8f4a78631a29159271d",
+ build_file = clean_dep("//third_party:com_google_absl.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "eigen_archive",
+ urls = [
+ "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
+ "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
+ ],
+ sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
+ strip_prefix = "eigen-eigen-fd6845384b86",
+ build_file = clean_dep("//third_party:eigen.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "arm_compiler",
+ sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969",
+ strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf",
+ urls = [
+ "https://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz",
+ # Please uncomment me when the next upgrade happens. Then
+ # remove the whitelist entry in third_party/repo.bzl.
+ # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz",
+ ],
+ build_file = clean_dep("//:arm_compiler.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "libxsmm_archive",
+ urls = [
+ "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz",
+ "https://github.com/hfp/libxsmm/archive/1.9.tar.gz",
+ ],
+ sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
+ strip_prefix = "libxsmm-1.9",
+ build_file = clean_dep("//third_party:libxsmm.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "ortools_archive",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz",
+ "https://github.com/google/or-tools/archive/v6.7.2.tar.gz",
+ ],
+ sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
+ strip_prefix = "or-tools-6.7.2/src",
+ build_file = clean_dep("//third_party:ortools.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "com_googlesource_code_re2",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/re2/archive/2018-04-01.tar.gz",
+ "https://github.com/google/re2/archive/2018-04-01.tar.gz",
+ ],
+ sha256 = "2f945446b71336e7f5a2bcace1abcf0b23fbba368266c6a1be33de3de3b3c912",
+ strip_prefix = "re2-2018-04-01",
+ system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "com_github_googlecloudplatform_google_cloud_cpp",
+ urls = [
+ "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
+ "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
+ ],
+ sha256 = "a34f3c50b237686dc870b13baaa6a5836ce3473f2f2a02717299f0ff318372db",
+ strip_prefix = "google-cloud-cpp-f875700a023bdd706333cde45aee8758b272c357",
+ )
+
+ tf_http_archive(
+ name = "com_github_googleapis_googleapis",
+ urls = [
+ "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
+ "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
+ ],
+ sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
+ strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
+ build_file = clean_dep("//third_party:googleapis.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "gemmlowp",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
+ "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
+ ],
+ sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
+ strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
+ )
+
+ tf_http_archive(
+ name = "farmhash_archive",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
+ "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
+ ],
+ sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
+ strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
+ build_file = clean_dep("//third_party:farmhash.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "highwayhash",
+ urls = [
+ "http://mirror.bazel.build/github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
+ "https://github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
+ ],
+ sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
+ strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
+ build_file = clean_dep("//third_party:highwayhash.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "nasm",
+ urls = [
+ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
+ "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ ],
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
+ build_file = clean_dep("//third_party:nasm.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "jpeg",
+ urls = [
+ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
+ "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
+ ],
+ sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde",
+ strip_prefix = "libjpeg-turbo-1.5.3",
+ build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "png_archive",
+ urls = [
+ "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
+ "https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
+ ],
+ sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
+ strip_prefix = "libpng-1.6.34",
+ build_file = clean_dep("//third_party:png.BUILD"),
+ patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
+ system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "org_sqlite",
+ urls = [
+ "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
+ "https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
+ ],
+ sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
+ strip_prefix = "sqlite-amalgamation-3240000",
+ build_file = clean_dep("//third_party:sqlite.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "gif_archive",
+ urls = [
+ "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
+ "http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
+ ],
+ sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
+ strip_prefix = "giflib-5.1.4",
+ build_file = clean_dep("//third_party:gif.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "six_archive",
+ urls = [
+ "https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
+ "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
+ ],
+ sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
+ strip_prefix = "six-1.10.0",
+ build_file = clean_dep("//third_party:six.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "astor_archive",
+ urls = [
+ "https://mirror.bazel.build/pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
+ "https://pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
+ ],
+ sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
+ strip_prefix = "astor-0.6.2",
+ build_file = clean_dep("//third_party:astor.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "gast_archive",
+ urls = [
+ "https://mirror.bazel.build/pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
+ "https://pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
+ ],
+ sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
+ strip_prefix = "gast-0.2.0",
+ build_file = clean_dep("//third_party:gast.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "termcolor_archive",
+ urls = [
+ "https://mirror.bazel.build/pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
+ "https://pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
+ ],
+ sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
+ strip_prefix = "termcolor-1.1.0",
+ build_file = clean_dep("//third_party:termcolor.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "absl_py",
+ urls = [
+ "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
+ "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
+ ],
+ sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
+ strip_prefix = "abseil-py-pypi-v0.2.2",
+ )
+
+ tf_http_archive(
+ name = "org_python_pypi_backports_weakref",
+ urls = [
+ "https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
+ "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
+ ],
+ sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
+ strip_prefix = "backports.weakref-1.0rc1/src",
+ build_file = clean_dep("//third_party:backports_weakref.BUILD"),
+ )
+
+ filegroup_external(
+ name = "org_python_license",
+ licenses = ["notice"], # Python 2.0
+ sha256_urls = {
+ "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [
+ "https://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt",
+ "https://docs.python.org/2.7/_sources/license.txt",
+ ],
+ },
+ )
+
+ tf_http_archive(
+ name = "protobuf_archive",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ ],
+ sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
+ strip_prefix = "protobuf-3.6.0",
+ )
+
+ # We need to import the protobuf library under the names com_google_protobuf
+ # and com_google_protobuf_cc to enable proto_library support in Bazel.
+ # Unfortunately there is no way to alias http_archives at the moment.
+ tf_http_archive(
+ name = "com_google_protobuf",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ ],
+ sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
+ strip_prefix = "protobuf-3.6.0",
+ )
+
+ tf_http_archive(
+ name = "com_google_protobuf_cc",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ ],
+ sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
+ strip_prefix = "protobuf-3.6.0",
+ )
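The two protobuf aliases above exist because Bazel's native proto rules look up the runtime under the fixed repository names @com_google_protobuf and @com_google_protobuf_cc. A hedged sketch of a consumer; the target names and .proto file are hypothetical:

    # Hypothetical BUILD targets illustrating why the archive must be imported
    # under the name com_google_protobuf: proto_library / cc_proto_library
    # resolve the protobuf runtime through that repository name.
    proto_library(
        name = "example_proto",
        srcs = ["example.proto"],
    )

    cc_proto_library(
        name = "example_cc_proto",
        deps = [":example_proto"],
    )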
+
+ tf_http_archive(
+ name = "nsync",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz",
+ "https://github.com/google/nsync/archive/1.20.0.tar.gz",
+ ],
+ sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd",
+ strip_prefix = "nsync-1.20.0",
+ )
+
+ tf_http_archive(
+ name = "com_google_googletest",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip",
+ "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip",
+ ],
+ sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d",
+ strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6",
+ )
+
+ tf_http_archive(
+ name = "com_github_gflags_gflags",
+ urls = [
+ "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
+ "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
+ ],
+ sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
+ strip_prefix = "gflags-2.2.1",
+ )
+
+ tf_http_archive(
+ name = "pcre",
+ sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5",
+ urls = [
+ "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
+ "http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
+ ],
+ strip_prefix = "pcre-8.42",
+ build_file = clean_dep("//third_party:pcre.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "swig",
+ sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453",
+ urls = [
+ "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
+ "http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
+ "http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
+ ],
+ strip_prefix = "swig-3.0.8",
+ build_file = clean_dep("//third_party:swig.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "curl",
+ sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5",
+ urls = [
+ "https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz",
+ "https://curl.haxx.se/download/curl-7.60.0.tar.gz",
+ ],
+ strip_prefix = "curl-7.60.0",
+ build_file = clean_dep("//third_party:curl.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "grpc",
+ urls = [
+ "https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz",
+ "https://github.com/grpc/grpc/archive/v1.13.0.tar.gz",
+ ],
+ sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
+ strip_prefix = "grpc-1.13.0",
+ system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "linenoise",
+ sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
+ urls = [
+ "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
+ "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
+ ],
+ strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
+ build_file = clean_dep("//third_party:linenoise.BUILD"),
+ )
+
+ # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror.
+ # Switch to an official source of snapshots if/when possible.
+ tf_http_archive(
+ name = "llvm",
+ urls = [
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428.tar.gz",
+ ],
+ sha256 = "c6cbb21acd46e3e00faa8c379595ecffb99ef77622da17f29371db2bfad1d3d3",
+ strip_prefix = "llvm-7b3bfc8151f3a6bcd9642c49c1f86f66cc43a428",
+ build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "lmdb",
+ urls = [
+ "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
+ "https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
+ ],
+ sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
+ strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
+ build_file = clean_dep("//third_party:lmdb.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "jsoncpp_git",
+ urls = [
+ "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
+ "https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
+ ],
+ sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
+ strip_prefix = "jsoncpp-1.8.4",
+ build_file = clean_dep("//third_party:jsoncpp.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "boringssl",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz",
+ "https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz",
+ ],
+ sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3",
+ strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778",
+ )
+
+ tf_http_archive(
+ name = "zlib_archive",
+ urls = [
+ "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
+ "https://zlib.net/zlib-1.2.11.tar.gz",
+ ],
+ sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+ strip_prefix = "zlib-1.2.11",
+ build_file = clean_dep("//third_party:zlib.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
+ )
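Archives that also carry a system_build_file, such as zlib_archive above, can be redirected to a system-provided library through the local_config_syslibs repository configured earlier. A minimal sketch of what such a system BUILD file might contain; the exact contents of //third_party/systemlibs:zlib.BUILD are an assumption here:

    # Sketch (assumption: approximates //third_party/systemlibs:zlib.BUILD);
    # instead of compiling the bundled sources, it links the system libz and
    # re-exports the same "zlib" target name that the bundled BUILD file defines.
    cc_library(
        name = "zlib",
        linkopts = ["-lz"],
        visibility = ["//visibility:public"],
    )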
+
+ tf_http_archive(
+ name = "fft2d",
+ urls = [
+ "https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
+ "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
+ ],
+ sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
+ build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "snappy",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz",
+ "https://github.com/google/snappy/archive/1.1.7.tar.gz",
+ ],
+ sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
+ strip_prefix = "snappy-1.1.7",
+ build_file = clean_dep("//third_party:snappy.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "nccl_archive",
+ urls = [
+ "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
+ "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
+ ],
+ sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
+ strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
+ build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "kafka",
+ urls = [
+ "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
+ "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
+ ],
+ sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c",
+ strip_prefix = "librdkafka-0.11.4",
+ build_file = clean_dep("//third_party:kafka/BUILD"),
+ patch_file = clean_dep("//third_party/kafka:config.patch"),
+ )
+
+ tf_http_archive(
+ name = "aws",
+ urls = [
+ "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
+ "https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
+ ],
+ sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
+ strip_prefix = "aws-sdk-cpp-1.3.15",
+ build_file = clean_dep("//third_party:aws.BUILD"),
+ )
+
+ java_import_external(
+ name = "junit",
+ jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a",
+ jar_urls = [
+ "https://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
+ "http://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar",
+ "http://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar",
+ ],
+ licenses = ["reciprocal"], # Common Public License Version 1.0
+ testonly_ = True,
+ deps = ["@org_hamcrest_core"],
+ )
+
+ java_import_external(
+ name = "org_hamcrest_core",
+ jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9",
+ jar_urls = [
+ "https://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
+ "http://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
+ "http://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar",
+ ],
+ licenses = ["notice"], # New BSD License
+ testonly_ = True,
+ )
+
+ tf_http_archive(
+ name = "jemalloc",
+ urls = [
+ "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
+ "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
+ ],
+ sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
+ strip_prefix = "jemalloc-4.4.0",
+ build_file = clean_dep("//third_party:jemalloc.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
+ )
+
+ java_import_external(
+ name = "com_google_testing_compile",
+ jar_sha256 = "edc180fdcd9f740240da1a7a45673f46f59c5578d8cd3fbc912161f74b5aebb8",
+ jar_urls = [
+ "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar",
+ "http://repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar",
+ ],
+ licenses = ["notice"], # New BSD License
+ testonly_ = True,
+ deps = ["@com_google_guava", "@com_google_truth"],
+ )
+
+ java_import_external(
+ name = "com_google_truth",
+ jar_sha256 = "032eddc69652b0a1f8d458f999b4a9534965c646b8b5de0eba48ee69407051df",
+ jar_urls = [
+ "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar",
+ "http://repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar",
+ ],
+ licenses = ["notice"], # Apache 2.0
+ testonly_ = True,
+ deps = ["@com_google_guava"],
+ )
+
+ java_import_external(
+ name = "org_checkerframework_qual",
+ jar_sha256 = "a17501717ef7c8dda4dba73ded50c0d7cde440fd721acfeacbf19786ceac1ed6",
+ jar_urls = [
+ "http://mirror.bazel.build/repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar",
+ "http://repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar",
+ ],
+ licenses = ["notice"], # Apache 2.0
+ )
+
+ java_import_external(
+ name = "com_squareup_javapoet",
+ jar_sha256 = "5bb5abdfe4366c15c0da3332c57d484e238bd48260d6f9d6acf2b08fdde1efea",
+ jar_urls = [
+ "http://mirror.bazel.build/repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar",
+ "http://repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar",
+ ],
+ licenses = ["notice"], # Apache 2.0
+ )
+
+ tf_http_archive(
+ name = "com_google_pprof",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
+ "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
+ ],
+ sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
+ strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
+ build_file = clean_dep("//third_party:pprof.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "cub_archive",
+ urls = [
+ "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
+ "https://github.com/NVlabs/cub/archive/1.8.0.zip",
+ ],
+ sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+ strip_prefix = "cub-1.8.0",
+ build_file = clean_dep("//third_party:cub.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "cython",
+ sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa",
+ urls = [
+ "https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz",
+ "https://github.com/cython/cython/archive/0.28.4.tar.gz",
+ ],
+ strip_prefix = "cython-0.28.4",
+ build_file = clean_dep("//third_party:cython.BUILD"),
+ delete = ["BUILD.bazel"],
+ system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "bazel_toolchains",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
+ ],
+ strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
+ sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
+ )
+
+ tf_http_archive(
+ name = "arm_neon_2_x86_sse",
+ sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5",
+ strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d",
+ urls = [
+ "https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
+ "https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
+ ],
+ build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "flatbuffers",
+ strip_prefix = "flatbuffers-1.9.0",
+ sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
+ "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
+ ],
+ build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:flatbuffers.BUILD"),
+ )
+
+ native.new_http_archive(
+ name = "double_conversion",
+ urls = [
+ "https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
+ ],
+ sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
+ strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
+ build_file = clean_dep("//third_party:double_conversion.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "tflite_mobilenet",
+ sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
+ ],
+ build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "tflite_mobilenet_ssd",
+ sha256 = "767057f2837a46d97882734b03428e8dd640b93236052b312b2f0e45613c1cf0",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
+ ],
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
+ tf_http_archive(
+ name = "tflite_mobilenet_ssd_quant",
+ sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
+ ],
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
+
+ tf_http_archive(
+ name = "tflite_conv_actions_frozen",
+ sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
+ ],
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
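Most rules in this file wrap build_file labels in clean_dep(), while a few of the tflite archives above pass str(Label(...)) directly; the two spellings are equivalent if clean_dep is the usual thin wrapper, which is an assumption about a helper defined earlier in workspace.bzl and not shown in this hunk:

    # Sketch (assumption: mirrors the helper defined near the top of
    # workspace.bzl); routing a "//third_party:..." string through Label()
    # anchors it to the TensorFlow repository even when this .bzl file is
    # loaded from another workspace.
    def clean_dep(dep):
        return str(Label(dep))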
+
+ tf_http_archive(
+ name = "tflite_smartreply",
+ sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
+ ],
+ build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "tflite_ovic_testdata",
+ sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
+ ],
+ build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
+ strip_prefix = "ovic",
+ )
+
+ tf_http_archive(
+ name = "build_bazel_rules_android",
+ sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
+ "https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
+ ],
+ strip_prefix = "rules_android-0.1.1",
+ )
+
+ tf_http_archive(
+ name = "ngraph",
+ urls = [
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ "https://github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ ],
+ sha256 = "cb35d3d98836f615408afd18371fb13e3400711247e0d822ba7f306c45e9bb2c",
+ strip_prefix = "ngraph-0.5.0",
+ build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "nlohmann_json_lib",
+ urls = [
+ "https://mirror.bazel.build/github.com/nlohmann/json/archive/v3.1.1.tar.gz",
+ "https://github.com/nlohmann/json/archive/v3.1.1.tar.gz",
+ ],
+ sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
+ strip_prefix = "json-3.1.1",
+ build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "ngraph_tf",
+ urls = [
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
+ "https://github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
+ ],
+ sha256 = "7919332cb15120101c3e05c1b969a5e029a6411581312583c8f80b6aaaa83072",
+ strip_prefix = "ngraph-tf-0.3.0-rc1",
+ build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
+ )
+
+ ##############################################################################
+ # BIND DEFINITIONS
+ #
+ # Please do not add bind() definitions unless we have no other choice.
+ # If that ends up being the case, please leave a comment explaining
+ # why we can't depend on the canonical build target.
+
+ # gRPC wants a cares dependency but its contents are not actually
+ # important since we have set GRPC_ARES=0 in tools/bazel.rc
+ native.bind(
+ name = "cares",
+ actual = "@grpc//third_party/nanopb:nanopb",
+ )
+
+ # Needed by Protobuf
+ native.bind(
+ name = "grpc_cpp_plugin",
+ actual = "@grpc//:grpc_cpp_plugin",
+ )
+ native.bind(
+ name = "grpc_python_plugin",
+ actual = "@grpc//:grpc_python_plugin",
+ )
+
+ native.bind(
+ name = "grpc_lib",
+ actual = "@grpc//:grpc++",
+ )
+
+ native.bind(
+ name = "grpc_lib_unsecure",
+ actual = "@grpc//:grpc++_unsecure",
+ )
+
+ # Needed by gRPC
+ native.bind(
+ name = "libssl",
+ actual = "@boringssl//:ssl",
+ )
+
+ # Needed by gRPC
+ native.bind(
+ name = "nanopb",
+ actual = "@grpc//third_party/nanopb:nanopb",
+ )
+
+ # Needed by gRPC
+ native.bind(
+ name = "protobuf",
+ actual = "@protobuf_archive//:protobuf",
+ )
+
+ # gRPC expects //external:protobuf_clib and //external:protobuf_compiler
+ # to point to Protobuf's compiler library.
+ native.bind(
+ name = "protobuf_clib",
+ actual = "@protobuf_archive//:protoc_lib",
+ )
+
+ # Needed by gRPC
+ native.bind(
+ name = "protobuf_headers",
+ actual = "@protobuf_archive//:protobuf_headers",
+ )
+
+ # Needed by Protobuf
+ native.bind(
+ name = "python_headers",
+ actual = clean_dep("//third_party/python_runtime:headers"),
+ )
+
+ # Needed by Protobuf
+ native.bind(
+ name = "six",
+ actual = "@six_archive//:six",
+ )
+
+ # Needed by gRPC
+ native.bind(
+ name = "zlib",
+ actual = "@zlib_archive//:zlib",
+ )
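Each bind() above publishes its target under //external, so code written against //external:&lt;name&gt; keeps working regardless of which repository actually provides the implementation. A hedged sketch of a consumer; the library name and source file are hypothetical:

    # Hypothetical consumer: depending on the //external:zlib alias created by
    # the bind() above rather than on @zlib_archive//:zlib directly.
    cc_library(
        name = "uses_zlib",
        srcs = ["uses_zlib.cc"],
        deps = ["//external:zlib"],
    )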